From 8118d84f773068ddda236507e7ccc2be5c6a728c Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Sat, 27 Sep 2025 19:03:48 +0600 Subject: [PATCH] [CK_TILE] Support f32 in FMHA (fwd and bwd) (#2836) * Support 16x16 (MFMA, WMMA) and 32x32 (MFMA) tiles in fwd and bwd BlockDropout Add comments with dropout implementation details Fix performance regression of fwd+dropout * Remove some usage of type punning (reinterpret_cast with ref or ptr) in Philox; * "scalarize" seed and offset, they may come either from kernel args or from device memory (presumably loaded with vector loads). These changes help the compiler to procude more optimal code and reduce register spilling. Use WarpGemmDispatcher instead of explicit WarpGemmMfma... to get CWarpDstrEncoding Use code based on BlockDropout in BlockDropoutBwd Refactor BlockDropout (fwd) Implement BlockDropout (fwd) for WMMA Originally BlockDropout only supported 32x32 tiles (IsWG32 = true), this version supports 16x16 tiles. If MPerBlock > MWarp * 16, it can generate numbers for two 16x16 tiles, similarly to BlockDropoutBwd. Implement BlockDropoutBwd for WMMA Remove MakeRandValLds* functions unused in BlockDropoutBwd Remove unused Run overload from BlockDropoutBwd * Fix regression with philox seed and offset when they exceed 32-bit int __builtin_amdgcn_readfirstlane works with 32-bit values, seed and offset are 64-bit so they get truncated. * Add F32 MFMA warp gemms * Support f32 in fwd FMHA * Implement transpose_vectors for 4-byte types (float) * Fix unexpected implicit f32->uint32 cast in buffer_store<4> __builtin_amdgcn_raw_buffer_store_b32 expects unsigned int but float was passed (implicitly casted to uint). mbuf_t types in other buffer_store<> are changed for consistency. * Support F32 in bwd FMHA hdim = 256 is disabled for now because it uses too much memory on gfx90a * Support Headdim = 48 (divisible by 16) in fwd * Add fp32-specific receipts (800 and 801) * Tune fwd tiles * Tune bwd tiles * Use small tiles only for small seqlen_q * Fix after rebasing * Fix selection of a fallback tile based on bm0 The assumption that the largest bm0 == 128 is not always true for current fp32 tiles. * Remove constraints and adjust filtering for fp32 Custom constraints are no longer needed because now the smallest tile is selected automtically based on seqlen_q. Filters related to qr_async_trload disabled valid fp32 tiles. * Add fp32 tests * Make splitkv and appendkv compile for fp32 only There are no instances yet, but API still must compile when only fp32 is requested. * Remove unimportant f32 instances * Add test_ck_tile_fmha_*_fp32 to REGRESSION_TESTS * Replace magic numbers with a constant, improve comments for dropout * Update changelog * Fix condition that dq_acc must be set to zero when mask is used The change was introduced in #2799 * Replace warp_uniform with recently added amd_wave_read_first_lane * Add hdim = 96 and 192 to fwd [ROCm/composable_kernel commit: 1edd250115bc3edd987b7d038f61290a0460d0a3] --- CHANGELOG.md | 3 +- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 4 +- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 7 + .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 33 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 74 +- .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 10 + .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 14 + .../codegen/ops/fmha_pagedkv_prefill.py | 6 + example/ck_tile/01_fmha/example_fmha_bwd.cpp | 8 +- example/ck_tile/01_fmha/example_fmha_fwd.cpp | 8 +- example/ck_tile/01_fmha/fmha_bwd_runner.hpp | 14 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 22 +- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 12 +- .../core/arch/amd_buffer_addressing.hpp | 12 +- include/ck_tile/core/utility/philox_rand.hpp | 16 +- .../core/utility/transpose_vectors.hpp | 10 +- .../reference_batched_dropout_randval.hpp | 12 +- .../ck_tile/ops/fmha/block/block_dropout.hpp | 713 ++++++++---------- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 3 + ...gedkv_pipeline_qr_ks_vs_default_policy.hpp | 25 +- ..._ks_vs_whole_k_prefetch_default_policy.hpp | 25 +- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 126 +++- .../ops/fmha/pipeline/tile_fmha_shape.hpp | 24 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 18 + .../warp/warp_gemm_attribute_mfma_impl.hpp | 131 +++- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 5 + test/CMakeLists.txt | 2 + test/ck_tile/fmha/CMakeLists.txt | 8 + test/ck_tile/fmha/test_fmha_bwd_fp32.cpp | 20 + test/ck_tile/fmha/test_fmha_fwd.inc | 6 + test/ck_tile/fmha/test_fmha_fwd_fp32.cpp | 39 + 31 files changed, 922 insertions(+), 488 deletions(-) create mode 100644 test/ck_tile/fmha/test_fmha_bwd_fp32.cpp create mode 100644 test/ck_tile/fmha/test_fmha_fwd_fp32.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index fe1e7ef345..438320d907 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added benchmarking support for tile engine GEMM Multi D. * Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands. * Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM. -* Added tensor-wise quantization for CK_TILE GEMM +* Added support for f32 to FMHA (fwd/bwd). +* Added tensor-wise quantization for CK_TILE GEMM. ### Optimized diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 802c9e51d7..81d34484a5 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation FWD_DTYPE_MAP = { + "fp32" : "FmhaFwdFp32", "fp16" : "FmhaFwdFp16", "bf16" : "FmhaFwdBf16", "fp8" : "FmhaFwdFp8", @@ -12,6 +13,7 @@ FWD_DTYPE_MAP = { } BWD_DTYPE_MAP = { + "fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16" } diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 0d8f366d8a..e2f69fa49a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -601,6 +601,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl cond &= pipeline.F_squant == 'f' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index bd6a9044e9..7319ef7ea1 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -370,7 +370,14 @@ class FmhaBwdDQDKDVKernel: # TODO: design a more practical way to do it # this is current supported tile size. def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: - if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': + if dtype == 'fp32' and tr_load == 'f': + return [ + # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + ] + elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': return [ FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), @@ -865,6 +872,30 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm cond = dtype in ['fp16', 'bf16'] if not cond: continue + + # fp32 only, all variations + if receipt == 800: + cond = dtype == 'fp32' + cond &= dpad == dvpad + if not cond: + continue + # fp32 only, minimal set of parameters + elif receipt == 801: + cond = dtype == 'fp32' + cond &= hdim in [64, 128] + cond &= dpad == dvpad + cond &= mode == 'batch' + cond &= bias == 'no' + cond &= dropout == 'no' + cond &= mask == 's_no' + cond &= deterministic == "f" + if not cond: + continue + else: + # Don't build fp32 by default + if dtype == 'fp32': + continue + gen_dot_do_o[t.dot_do_o_kernel] = True gen_dq_dk_dv[t.dq_dk_dv_kernel] = True if not t.convert_dq_kernel.disabled: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index da0c9ca931..f898d5f7b2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -25,6 +25,7 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = { 32 : 32, + 48 : 48, 64 : 64, 96 : 128, 128: 128, @@ -164,7 +165,7 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; - const bool has_load_tr = ck_tile::is_load_tr_supported(); + [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} return r; @@ -249,9 +250,8 @@ class FmhaFwdApiTrait: else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False - @property - def seqtune(self) -> str: - if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true + def seqtune(self, max_bm0 : int) -> str: + if self.bm0 == max_bm0: return 'true/*fall back to largest tile*/' else: return f'a.seqlen_q <= {self.bm0}' @@ -386,6 +386,7 @@ class FmhaFwdApiPool: per_hdim_case=str() for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] + max_bm0 = max((t.bm0 for t in traits), default=0) inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' @@ -393,7 +394,7 @@ class FmhaFwdApiPool: F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, @@ -534,7 +535,20 @@ class KernelComponentFactory: # this is current supported tile size per hdim @staticmethod def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': + if dtype == 'fp32': + return { + # bm0, bn0, bk0, bn1, bk1, + ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + } + elif dtype == 'fp16' or dtype == 'bf16': return { (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), @@ -572,7 +586,13 @@ class KernelComponentFactory: # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? pipelines = [] - if dtype in ['fp16', 'bf16']: + if dtype in ['fp32']: + squant = 'f' + for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + elif dtype in ['fp16', 'bf16']: squant = 'f' for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: @@ -626,6 +646,8 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, next_tile in zip(tiles, tiles[1:]): + assert next_tile.F_bm0 >= tile.F_bm0, 'Tiles must be ordered by increasing bm0' for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): if mode == "group": if pipeline.F_spad != 't' or pipeline.F_skpad != 't': @@ -635,12 +657,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue - if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): - continue + if dtype != 'fp32': + if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + continue + if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + continue # logits_soft_cap is only allowed if no bias if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): continue @@ -710,6 +733,31 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl if not cond: continue + # fp32 only, all variations + if receipt == 800: + cond = dtype == 'fp32' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + if not cond: + continue + # fp32 only, minimal set of parameters + elif receipt == 801: + cond = dtype == 'fp32' + cond &= hdim in [48, 128] + cond &= mode == 'batch' + cond &= pipeline.F_bias == 'no' + cond &= pipeline.F_lse == 'f' + cond &= pipeline.F_dropout == 'f' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + cond &= pipeline.F_mask == 's_no' + if not cond: + continue + else: + # Don't build fp32 by default + if dtype == 'fp32': + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 0ebeaddf9c..38491b56c4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -184,6 +184,9 @@ class FmhaFwdAppendKVApiPool: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) @dataclass @@ -341,6 +344,13 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op cond &= pipeline.F_vlayout == 'row' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index cee1505486..281357ef1e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -768,6 +768,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, opt cond &= pipeline.F_squant == 'f' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) @@ -834,6 +841,13 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim cond = dtype in ['fp16', 'bf16'] if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + gen.append(k) return gen diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index df6b422981..3624b7b387 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -560,6 +560,12 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl if not cond: continue + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/example_fmha_bwd.cpp b/example/ck_tile/01_fmha/example_fmha_bwd.cpp index e0e1fba668..73b3c1e619 100644 --- a/example/ck_tile/01_fmha/example_fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_bwd.cpp @@ -43,7 +43,7 @@ auto create_args(int argc, char* argv[]) "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" "a(libi) or 2, alibi with 1*h. a:1, b*h") .insert("dbias", "0", "output bias gradient or not") - .insert("prec", "fp16", "data type. fp16 or bf16") + .insert("prec", "fp16", "data type. fp32/fp16/bf16") .insert("mask", "0", "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" @@ -159,7 +159,11 @@ int main(int argc, char* argv[]) return -1; const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + if(data_type == "fp32") + { + return run(arg_parser) == bwd_result::success ? 0 : -2; + } + else if(data_type == "fp16") { return run(arg_parser) == bwd_result::success ? 0 : -2; } diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index 79fda6d564..c27a5ce1ae 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -67,7 +67,7 @@ auto create_args(int argc, char* argv[]) "n or 0, no bias\n" "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" "a(libi) or 2, alibi with 1*h. a:1, b*h") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8") .insert("mask", "0", "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" @@ -227,7 +227,11 @@ int main(int argc, char* argv[]) return -1; const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + if(data_type == "fp32") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "fp16") { return run(arg_parser) == fwd_result::success ? 0 : -2; } diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index d861b351d4..b6f2c8ca30 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -35,6 +35,14 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) +{ + double rtol = 1e-4; + double atol = 1e-4; + return ck_tile::make_tuple(rtol, atol); +} + template <> auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) { @@ -77,7 +85,9 @@ bwd_result fmha_bwd_run(mode_enum mode, std::optional json = std::nullopt) { const std::string data_type = []() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) + return "fp32"; + else if constexpr(std::is_same_v) return "fp16"; else if constexpr(std::is_same_v) return "bf16"; @@ -776,7 +786,7 @@ bwd_result fmha_bwd_run(mode_enum mode, // non-deterministic kernels use atomic add to write dq // Some block may be skipped with causal mask and dq are not set to zeros // In these cases thus we need to zero out it first - if(!deterministic || mask.type == mask_enum::no_mask) + if(!deterministic || mask.type != mask_enum::no_mask) dq_acc_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index f5dd42a6bd..761def6d6a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,6 +17,10 @@ #include #include +struct FmhaFwdFp32 +{ +}; + struct FmhaFwdFp16 { }; @@ -48,6 +52,22 @@ struct FmhaFwdFp8Fp32 template struct FmhaFwdTypeConfig; +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = float; + using KDataType = float; + using VDataType = float; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = float; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = float; +}; + template <> struct FmhaFwdTypeConfig { diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index e58e040f19..0703af71e3 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -41,6 +41,14 @@ auto get_elimit(std::string /*init_method*/) return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + template <> auto get_elimit(std::string /*init_method*/) { @@ -180,7 +188,9 @@ fwd_result fmha_fwd_run(mode_enum mode, std::optional json = std::nullopt) { const std::string data_type = []() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) + return "fp32"; + else if constexpr(std::is_same_v) return "fp16"; else if constexpr(std::is_same_v) return "bf16"; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7bc5ca5df8..de3427c33d 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -470,7 +470,7 @@ struct buffer_store<16> index_t /*flag*/ = 1) { static_assert(sizeof(T) == 16); - using mbuf_t = fp32x4_t; + using mbuf_t = uint32x4_t; #if HAS_RAW_BUFFER_BUILTINS index_t s_offset = i_offset; __builtin_amdgcn_raw_buffer_store_b128( @@ -496,7 +496,7 @@ struct buffer_store<8> index_t /*flag*/ = 1) { static_assert(sizeof(T) == 8); - using mbuf_t = fp32x2_t; + using mbuf_t = uint32x2_t; #if HAS_RAW_BUFFER_BUILTINS index_t s_offset = i_offset; __builtin_amdgcn_raw_buffer_store_b64( @@ -522,7 +522,7 @@ struct buffer_store<4> index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); - using mbuf_t = float; + using mbuf_t = uint32_t; #if HAS_RAW_BUFFER_BUILTINS index_t s_offset = i_offset; __builtin_amdgcn_raw_buffer_store_b32( @@ -548,7 +548,7 @@ struct buffer_store<2> index_t /*flag*/ = 1) { static_assert(sizeof(T) == 2); - using mbuf_t = short; + using mbuf_t = uint16_t; #if HAS_RAW_BUFFER_BUILTINS index_t s_offset = i_offset; __builtin_amdgcn_raw_buffer_store_b16( @@ -573,8 +573,8 @@ struct buffer_store<1> index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { - static_assert(sizeof(T) == 4); - using mbuf_t = float; + static_assert(sizeof(T) == 1); + using mbuf_t = uint8_t; #if HAS_RAW_BUFFER_BUILTINS index_t s_offset = i_offset; __builtin_amdgcn_raw_buffer_store_b8( diff --git a/include/ck_tile/core/utility/philox_rand.hpp b/include/ck_tile/core/utility/philox_rand.hpp index 87abf5cc18..52b1489543 100644 --- a/include/ck_tile/core/utility/philox_rand.hpp +++ b/include/ck_tile/core/utility/philox_rand.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -55,7 +55,8 @@ class philox CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out, const unsigned long long subsequence, - const index_t start_idx) const + const index_t idx0, + const index_t idx1) const { uint4 tmp_ph; tmp_ph = get_philox_4x32(subsequence); @@ -66,13 +67,12 @@ class philox tmp[2] = tmp_ph.z; tmp[3] = tmp_ph.w; uint32_t* out_tmp = reinterpret_cast(&out[0]); - out_tmp[0] = tmp[start_idx]; - out_tmp[1] = tmp[start_idx + 2]; + out_tmp[0] = tmp[idx0]; + out_tmp[1] = tmp[idx1]; } - CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out, - const unsigned long long subsequence, - const index_t start_idx) const + CK_TILE_HOST_DEVICE void + get_random_4x8(uint8_t* out, const unsigned long long subsequence, const index_t idx) const { uint4 tmp_ph; tmp_ph = get_philox_4x32(subsequence); @@ -83,7 +83,7 @@ class philox tmp[2] = tmp_ph.z; tmp[3] = tmp_ph.w; uint32_t* out_tmp = reinterpret_cast(&out[0]); - out_tmp[0] = tmp[start_idx]; + out_tmp[0] = tmp[idx]; } private: diff --git a/include/ck_tile/core/utility/transpose_vectors.hpp b/include/ck_tile/core/utility/transpose_vectors.hpp index 497fd3b948..f0d7dae706 100644 --- a/include/ck_tile/core/utility/transpose_vectors.hpp +++ b/include/ck_tile/core/utility/transpose_vectors.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -34,7 +34,13 @@ struct transpose_vectors constexpr auto I3 = number<3>{}; constexpr auto I4 = number<4>{}; - if constexpr(sizeof(S) == 2) + if constexpr(sizeof(S) == 4) + { + static_for<0, NY, 1>{}([&](auto iy) { + static_for<0, NX, 1>{}([&](auto ix) { vy_tuple(iy)(ix) = vx_tuple[ix][iy]; }); + }); + } + else if constexpr(sizeof(S) == 2) { static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!"); diff --git a/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp b/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp index 2a02adaee3..ec6c6009b7 100644 --- a/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp +++ b/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp @@ -33,18 +33,22 @@ reference_batched_dropout_randval(HostTensor& randval_b_m // With SFactor = 2 it becomes: // C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8) // C j: (lane % 32) + // See ck_tile/ops/fmha/block/block_dropout.hpp for more details. - constexpr index_t max_warp_size = 64; - constexpr index_t warp_gemm_mn = 32; + // The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values + constexpr index_t philox_per_tile = 64; + constexpr index_t warp_gemm_mn = 32; const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn); const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn); auto f = [&](index_t i_h, index_t row, index_t col) { uint2 rowcol = make_uint2(row, col); - for(index_t lane = 0; lane < max_warp_size; lane++) + for(index_t lane = 0; lane < philox_per_tile; lane++) { - philox ph(drop_seed, drop_offset + (batch * nhead + i_h) * max_warp_size + lane); + const uint64_t ph_head_offset = drop_offset + (batch * nhead + i_h) * philox_per_tile; + const index_t ph_offset = lane; + philox ph(drop_seed, ph_head_offset + ph_offset); uint8_t random_uint8_t[16]; ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp index e036402e16..8abdd54cd9 100644 --- a/include/ck_tile/ops/fmha/block/block_dropout.hpp +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -1,17 +1,44 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" namespace ck_tile { +// BlockDropoutBwd and BlockDropout (fwd) support two warp gemm tile sizes: 32x32 (MFMA only) and +// 16x16 (MFMA and WMMA). Even if fwd and bwd use different tile sizes, generated random +// numbers will be the same, they are also the same for MFMA (on CDNA), WMMA (on RDNA), or host +// (for verification, see ck_tile/host/reference/reference_batched_dropout_randval.hpp). +// +// The (row, col) coordinate of the current 32x32 tile in the P matrix determines a subsequence of +// random numbers (ph_subsequence). +// The (batch, head, 0..63) coordinate determines an offset in the subsequence (ph_head_offset and +// ph_offset). +// This means that subsequences are non-overlapping, reproducible and independent of mask or window. +// +// There are 3 modes (all produce the same results): +// * For 32x32 MFMA tile each of 64 lanes generates 4 * 32 bits or 16 bytes, so one warp generates +// the entire 32x32 tile (64 * 16 = 32 * 32). +// * For 16x16 MFMA tile one warp generates 1/4 of the 32x32 tile ((16 * 16) / (64 * 16) = 1/4), 4 +// warps generate the same 64 * 16 random bytes and each uses its own quarter. If kMPerBlock > +// MWarp * WG::kM one warp can generate two 16x16 tiles (MIterPerWarp = 2) so fewer instructions +// are needed for generating a 32x32 tile. +// * For 16x16 WMMA tile one warp generates 1/2 of the 32x32 tile ((16 * 16) / (32 * 16) = 1/2), 2 +// warps generate the same 64 * 16 random bytes and each uses its own half. If kMPerBlock > MWarp * +// WG::kM one warp can generate two 16x16 tiles. + +namespace detail { +// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values +constexpr index_t philox_per_tile = 64; +} // namespace detail + struct NullBlockDropout { template - __host__ __device__ static constexpr auto + CK_TILE_HOST_DEVICE static constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, index_t seqlen_qk_start) { @@ -32,7 +59,9 @@ struct BlockDropout float rp_undrop_, uint8_t p_undrop_in_uint8_t_, bool is_store_randval_) - : ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()), + : ph_seed(amd_wave_read_first_lane(seed)), + ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) * + detail::philox_per_tile)), rp_undrop(rp_undrop_), p_undrop_in_uint8_t(p_undrop_in_uint8_t_), is_store_randval(is_store_randval_) @@ -46,11 +75,15 @@ struct BlockDropout { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - constexpr index_t kMPerStep = MWarp * WG::kM; - constexpr index_t kNPerStep = NWarp * WG::kN; + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); auto randval_dram_window = [&]() { @@ -78,12 +111,17 @@ struct BlockDropout { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t kMPerStep = MWarp * WG::kM; - constexpr index_t kNPerStep = WG::kN; - constexpr index_t kN1 = 8; - constexpr index_t kN0 = kNPerStep / kN1; + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + constexpr index_t kN1 = 8; + constexpr index_t kN0 = kNPerStep / kN1; constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor( ck_tile::make_tuple(number{}, number{}, number{}), @@ -107,33 +145,35 @@ struct BlockDropout { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t MIterPerWarp = 1; + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; constexpr index_t NIterPerWarp = 1; + // The tile distribution is different from the one in MakeRandValLdsShuffleTileDistribution, + // because it can combine 2 (MIterPerWarp) 16x16 subtiles for generating them at once constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< sequence<>, - tuple, sequence>, + tuple, sequence>, tuple>, - tuple>, + tuple>, sequence<1, 2>, - sequence<0, 0>>{}; + sequence<1, 0>>{}; // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. - constexpr auto randval_block_inner_part_dstr_encoding = []() { - if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; - } - else - { - return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; - } - }(); + constexpr auto randval_block_inner_part_dstr_encoding = + typename WarpGemmDispatcher::CWarpDstrEncoding{}; constexpr auto randval_block_part_dstr_encode = detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, @@ -147,11 +187,13 @@ struct BlockDropout { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t MIterPerWarp = 1; + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; constexpr index_t NIterPerWarp = 1; constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< @@ -181,14 +223,16 @@ struct BlockDropout { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t kNPerBlock = BlockGemmShape::kN; - constexpr index_t kMPerStep = MWarp * WG::kM; - constexpr index_t kNPerStep = NWarp * WG::kN; + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; // randval tile in LDS auto randval_lds = make_tensor_view( @@ -200,42 +244,100 @@ struct BlockDropout // register distribute auto randval_dist_generated = make_static_distributed_tensor(MakeRandValTileDistribution()); - static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); - auto randval_lds_read_window = + const auto randval_lds_read_window = make_tile_window(randval_lds_window.get_bottom_tensor_view(), randval_lds_window.get_window_lengths(), randval_lds_window.get_window_origin(), MakeRandValLdsShuffleTileDistribution()); - const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); + const index_t start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + auto generate_randval = [&](auto i_m0, auto i_n0) { + // Generate random numbers + uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize]; + const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp; + const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp); + if constexpr(IsWG32) + { + // Generate the whole 32x32 tile at once (each tile consists of random numbers taken + // from a separate subsequence of Philox) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0, wg_n0)); + const index_t ph_offset = get_lane_id(); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + else + { + // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether + // MIterPerWarp is equal to 1 or 2) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0 / 2, wg_n0 / 2)); + const index_t subtile_m0 = wg_m0 % 2; + if constexpr(get_warp_size() == 32) + { + const index_t ph_offset = (get_lane_id() & 15) + + (((get_lane_id() >> 4) & 1) << 5) + + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + } + else + { + const index_t subtile_n0 = (get_lane_id() >> 4) & 1; + const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 4); + ph.get_random_4x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0); + } + } + } + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + // Transpose randval using LDS + store_tile(randval_lds_window, randval_dist_generated); + block_sync_lds(); + const auto randval = load_tile(randval_lds_read_window); + block_sync_lds(); + return randval; + }; + if(is_store_randval) { static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { - int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); - int block_col_start = (start_n0_idx / WG::kN) + i_n0; - uint2 rowcol = make_uint2(block_row_start, block_col_start); - - // generate random number - uint8_t random_uint8_t[16]; - ph.get_random_16x8(random_uint8_t, - reinterpret_cast(rowcol)); - - constexpr auto randval_dist_generated_spans = - decltype(randval_dist_generated)::get_distributed_spans(); - int i_random_idx = 0; - sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); - randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; - }); - }); - // save to LDS - store_tile(randval_lds_window, randval_dist_generated); - block_sync_lds(); - // read from LDS to register - auto randval = load_tile(randval_lds_read_window); + const auto randval = generate_randval(i_m0, i_n0); // save to Global const auto randval_store = cast_tile(randval); store_tile(randval_dram_window, randval_store); @@ -244,37 +346,21 @@ struct BlockDropout move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); }); move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); - }; + } static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { - int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); - int block_col_start = (start_n0_idx / WG::kN) + i_n0; - uint2 rowcol = make_uint2(block_row_start, block_col_start); - - // generate random number - uint8_t random_uint8_t[16]; - ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); - - constexpr auto randval_dist_generated_spans = - decltype(randval_dist_generated)::get_distributed_spans(); - int i_random_idx = 0; - sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); - randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; - }); - }); - // save to LDS - store_tile(randval_lds_window, randval_dist_generated); - block_sync_lds(); - // read from LDS to register - auto randval = load_tile(randval_lds_read_window); + const auto randval = generate_randval(i_m0, i_n0); + // Drop values of P based on the generated probabilities constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { - constexpr auto p_idx0 = tile_distributed_index{}; + constexpr auto p_idx0 = + tile_distributed_index()>{}; constexpr auto p_idx1 = - tile_distributed_index{}; + tile_distributed_index(), + idx1.impl_.template at<2>()>{}; constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t @@ -286,12 +372,15 @@ struct BlockDropout }); } - ck_tile::philox ph; + const unsigned long long ph_seed; + const unsigned long long ph_head_offset; const float rp_undrop; const uint8_t p_undrop_in_uint8_t; const bool is_store_randval; }; +// TODO: IsWG32_ is not needed as template parameter and can be removed. IsDropout_ == false can be +// replaced with NullBlockDropout. This requires changes in xformers and other libs. template struct BlockDropoutBwd; @@ -301,8 +390,8 @@ struct BlockDropoutBwd static constexpr bool IsDropout = false; static constexpr bool IsStoreRandval = IsStoreRandval_; - template - __host__ __device__ static constexpr auto + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, index_t seqlen_qk_start) { @@ -316,10 +405,7 @@ struct BlockDropoutBwd template struct BlockDropoutBwd { - static constexpr bool IsDropout = true; - // true: 32*32 warp gemm - // false: 16*16 warp gemm - static constexpr bool IsWG32 = IsWG32_; + static constexpr bool IsDropout = true; static constexpr bool IsStoreRandval = IsStoreRandval_; CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, @@ -329,38 +415,30 @@ struct BlockDropoutBwd unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_) - : ph(seed, - offset + (i_batch * nheads + i_head) * get_warp_size() + - (IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))), + : ph_seed(amd_wave_read_first_lane(seed)), + ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) * + detail::philox_per_tile)), rp_undrop(rp_undrop_), p_undrop_in_uint8_t(p_undrop_in_uint8_t_) { } - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, index_t seqlen_qk_start) { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using BlockGemmShape = remove_cvref_t; - using WG = remove_cvref_t())>; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16); - constexpr index_t kMPerStep = [&]() { - if constexpr(MBwdWG16MultiIterCheck) - { - return MWarp * WG::kM * 2; - } - else - { - return MWarp * WG::kM; - } - }(); - constexpr index_t kNPerStep = NWarp * WG::kN; + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); auto randval_dram_window = [&]() { @@ -384,85 +462,39 @@ struct BlockDropoutBwd } template - CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor() - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t kMPerStep = MWarp * WG::kM; - constexpr index_t kNPerStep = WG::kN; - constexpr index_t kN1 = 8; - constexpr index_t kN0 = kNPerStep / kN1; - - constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor( - ck_tile::make_tuple(number{}, number{}, number{}), - ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto randval_lds_block_desc = transform_tensor_descriptor( - randval_lds_block_desc_0, - ck_tile::make_tuple( - make_pass_through_transform(number{}), - make_merge_transform(ck_tile::make_tuple(number{}, number{}))), - ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}), - ck_tile::make_tuple(sequence<0>{}, sequence<1>{})); - - return randval_lds_block_desc; - } - - template CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16); - - constexpr index_t MIterPerWarp = [&]() { - if constexpr(MBwdWG16MultiIterCheck) - { - return 2; - } - else - { - return 1; - } - }(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; constexpr index_t NIterPerWarp = 1; constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< sequence<>, - tuple, sequence>, + tuple, sequence>, tuple>, - tuple>, + tuple>, sequence<1, 2>, - sequence<0, 0>>{}; + sequence<1, 0>>{}; - // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. - // except headdim256. - constexpr auto randval_block_inner_part_dstr_encoding = []() { - if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - if constexpr(IsWG32) - return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; - else - return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{}; - } - else - { - if constexpr(IsWG32) - return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; - else - return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{}; - } - }(); + constexpr auto randval_block_inner_part_dstr_encoding = + typename WarpGemmDispatcher::CWarpDstrEncoding{}; + static_assert( + std::is_same_v, + typename WG::CWarpDstrEncoding>); constexpr auto randval_block_part_dstr_encode = detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, @@ -471,129 +503,6 @@ struct BlockDropoutBwd return make_static_tile_distribution(randval_block_part_dstr_encode); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution() - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t MIterPerWarp = 1; - constexpr index_t NIterPerWarp = 1; - - constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto randval_block_part_dstr_encode = - detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, - typename WG::CWarpDstrEncoding{}); - - return make_static_tile_distribution(randval_block_part_dstr_encode); - } - - template - CK_TILE_HOST_DEVICE void Run(void* randval_ptr, - const index_t start_m0_idx, - const index_t start_n0_idx, - PComputeWindow& p_compute, - RandValDramWindow& randval_dram_window) const - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t kNPerBlock = BlockGemmShape::kN; - constexpr index_t kMPerStep = MWarp * WG::kM; - constexpr index_t kNPerStep = NWarp * WG::kN; - - // randval tile in LDS - auto randval_lds = make_tensor_view( - reinterpret_cast(randval_ptr), MakeRandValLdsBlockDescriptor()); - - auto randval_lds_window = make_tile_window( - randval_lds, MakeRandValLdsBlockDescriptor().get_lengths(), {0, 0}); - - // register distribute - auto randval_dist_generated = - make_static_distributed_tensor(MakeRandValTileDistribution()); - static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); - - auto randval_lds_read_window = - make_tile_window(randval_lds_window.get_bottom_tensor_view(), - randval_lds_window.get_window_lengths(), - randval_lds_window.get_window_origin(), - MakeRandValLdsShuffleTileDistribution()); - - static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { - static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { - int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); - int block_col_start = (start_n0_idx / WG::kN) + i_n0; - uint2 rowcol = make_uint2(block_row_start, block_col_start); - - // generate random number - uint8_t random_uint8_t[16]; - ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); - - constexpr auto randval_dist_generated_spans = - decltype(randval_dist_generated)::get_distributed_spans(); - int i_random_idx = 0; - sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); - randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; - }); - }); - // save to LDS - store_tile(randval_lds_window, randval_dist_generated); - block_sync_lds(); - // read from LDS to register - auto randval = load_tile(randval_lds_read_window); - constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); - sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { - constexpr auto p_idx0 = tile_distributed_index{}; - constexpr auto p_idx1 = - tile_distributed_index{}; - constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); - constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); - p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t - ? p_compute[p_idx] * rp_undrop - : PComputeDataType(0); - }); - }); - // save to Global - if constexpr(IsStoreRandval) - { - const auto randval_store = cast_tile(randval); - store_tile(randval_dram_window, randval_store); - move_tile_window(randval_dram_window, {0, kNPerStep}); - } - }); - if constexpr(IsStoreRandval) - { - move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); - } - }); - if constexpr(IsStoreRandval) - { - move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); - } - } - template { constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t kNPerBlock = BlockGemmShape::kN; - constexpr bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16); - constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16); - constexpr index_t kMPerStep = [&]() { - if constexpr(MBwdWG16MultiIterCheck) + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // register distribute + auto randval_dist_generated = + make_static_distributed_tensor(MakeRandValTileDistribution()); + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + auto generate_randval = [&](auto i_m0, auto i_n0) { + // Generate random numbers + uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize]; + const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp; + const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp); + if constexpr(IsWG32) { - return MWarp * WG::kM * 2; + // Generate the whole 32x32 tile at once (each tile consists of random numbers + // taken from a separate subsequence of Philox) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0, wg_n0)); + const index_t ph_offset = get_lane_id(); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); } else { - return MWarp * WG::kM; + // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether + // MIterPerWarp is equal to 1 or 2) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0 / 2, wg_n0 / 2)); + const index_t subtile_m0 = wg_m0 % 2; + if constexpr(get_warp_size() == 32) + { + const index_t ph_offset = (get_lane_id() & 15) + + (((get_lane_id() >> 4) & 1) << 5) + + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + } + else + { + const index_t subtile_n0 = (get_lane_id() >> 4) & 1; + const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 4); + ph.get_random_4x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0); + } + } } - }(); - constexpr index_t kNPerStep = NWarp * WG::kN; - // register distribute - auto randval = make_static_distributed_tensor( - MakeRandValTileDistribution()); - if constexpr(IsWG32) - static_assert(randval.kThreadElementSpaceSize == 16); - else - static_assert(randval.kThreadElementSpaceSize == 4 || - randval.kThreadElementSpaceSize == 8); + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + return randval_dist_generated; + }; static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { - int block_row_start, block_col_start; - if constexpr(IsWG32) - { - block_row_start = (start_m0_idx / WG::kM) + i_m0; - block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id(); - } - else - { - block_row_start = start_m0_idx / 32 + i_m0; - block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2; - } - uint2 rowcol = make_uint2(block_row_start, block_col_start); - - // generate random number - uint8_t* random_uint8_t_; - if constexpr(MBwdWG16SingleIterCheck) - { - uint8_t random_uint8_t[4]; - // m0t0 ~m0t15/m0t32~m0t47: 0 - // m0t16~m0t31/m0t48~m0t63: 1 - // m1t0 ~m1t15/m1t32~m1t47: 2 - // m1t16~m1t31/m1t48~m1t63: 3 - const index_t start_idx = - ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1); - ph.get_random_4x8( - random_uint8_t, reinterpret_cast(rowcol), start_idx); - random_uint8_t_ = random_uint8_t; - } - else if constexpr(MBwdWG16MultiIterCheck) - { - uint8_t random_uint8_t[8]; - // t0 ~t15/t32~t47: 0 - // t16~t31/t48~t63: 1 - const index_t start_idx = (get_lane_id() >> 4) & 1; - ph.get_random_8x8( - random_uint8_t, reinterpret_cast(rowcol), start_idx); - random_uint8_t_ = random_uint8_t; - } - else - { - uint8_t random_uint8_t[16]; - ph.get_random_16x8(random_uint8_t, - reinterpret_cast(rowcol)); - random_uint8_t_ = random_uint8_t; - } - + const auto randval = generate_randval(i_m0, i_n0); + // Drop values of P based on the generated probabilities, negative sign is used to + // distinguish such values ​​later in bwd pipeline. constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); - int i_random_idx = 0; sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { - constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); - randval(r_idx) = random_uint8_t_[i_random_idx++]; - constexpr auto p_idx0 = tile_distributed_index{}; + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + constexpr auto p_idx0 = + tile_distributed_index(), + idx0.impl_.template at<1>(), + idx0.impl_.template at<2>()>{}; constexpr auto p_idx1 = tile_distributed_index{}; constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t @@ -717,7 +645,8 @@ struct BlockDropoutBwd } } - ck_tile::philox ph; + const unsigned long long ph_seed; + const unsigned long long ph_head_offset; const float rp_undrop; const uint8_t p_undrop_in_uint8_t; }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index b2b00a07e4..980dfb06ae 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -82,6 +82,7 @@ struct FmhaBwdDQDKDVKernel // clang-format off template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; template <> struct t2s { static constexpr const char * name = "fp16"; }; template <> struct t2s { static constexpr const char * name = "bf16"; }; // clang-format on @@ -1187,6 +1188,7 @@ struct FmhaBwdOGradDotOKernel // clang-format off template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; template <> struct t2s { static constexpr const char * name = "fp16"; }; template <> struct t2s { static constexpr const char * name = "bf16"; }; // clang-format on @@ -1443,6 +1445,7 @@ struct FmhaBwdConvertQGradKernel // clang-format off template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; template <> struct t2s { static constexpr const char * name = "fp16"; }; template <> struct t2s { static constexpr const char * name = "bf16"; }; // clang-format on diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp index 9c348495ff..f7ee88f906 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -32,12 +32,27 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy constexpr auto warp_gemm = []() { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); - static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); - if constexpr(std::is_same_v && - std::is_same_v && + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 16); + + return WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) @@ -49,6 +64,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index 67ab548dab..050eb48384 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -264,12 +264,27 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy constexpr auto warp_gemm = []() { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); - static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); - if constexpr(std::is_same_v && - std::is_same_v && + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 16); + + return WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) @@ -281,6 +296,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index dccb41ba44..9dba3c85d5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -73,12 +73,27 @@ struct BlockFmhaPipelineQXCustomPolicy constexpr auto warp_gemm = []() { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); - static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); - if constexpr(std::is_same_v && - std::is_same_v && + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 16); + + return WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) @@ -90,6 +105,8 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) @@ -201,7 +218,7 @@ struct BlockFmhaPipelineQXCustomPolicy constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), - number<8>{}, + number{}, number<1>{}); constexpr auto q_lds_block_desc = transform_tensor_descriptor( @@ -228,14 +245,29 @@ struct BlockFmhaPipelineQXCustomPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; - constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); - static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); - constexpr auto warp_gemm = []() { - if constexpr(std::is_same_v && - std::is_same_v && + constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 16); + + return WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) @@ -247,6 +279,8 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) @@ -258,6 +292,8 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 32); + // TODO: hard coded here. Otherwise, it may incorrect result constexpr index_t swizzle_factor = 4; return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution< @@ -507,7 +543,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy{}, number{}, number{}), make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number{}, number<1>{}), - number<8>{}, + number{}, number<1>{}); constexpr auto k_lds_block_desc = transform_tensor_descriptor( @@ -806,15 +842,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, @@ -824,7 +859,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, // N0 K2 N2 sequence<0, 2, 2>>{}); } - else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0) + else if constexpr(get_warp_size() % (K2 * N0) == 0) { constexpr index_t K1 = get_warp_size() / (K2 * N0); constexpr index_t K0 = kBlockSize / get_warp_size(); @@ -863,13 +898,40 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, tuple, sequence>, - tuple, sequence<1, 2>>, + tuple, sequence<1, 2>>, // N1, N2 K0 tuple, sequence<2, 0>>, - sequence<1, 2>, + sequence<1, 2>, // N0 K1 sequence<0, 1>>{}); + if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock) + { + return dstr; + } + else + { + static_assert(kKPerBlock % 16 == 0); + constexpr index_t kKPerIter = kKPerBlock % 32 == 0 ? 32 : 16; + constexpr index_t K0_m = kKPerBlock / kKPerIter; + constexpr index_t K2 = 2; + constexpr index_t K1_m = kKPerIter / K2; + constexpr index_t N2_m = get_warp_size() / K1_m; + constexpr index_t N0_m = kNPerBlock / (N2_m * N1); + constexpr auto dstr_m = make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<1, 2>>, // N1, N2 K1 + tuple, sequence<2, 1>>, + sequence<2, 1, 2>, // K0 N0 K2 + sequence<0, 0, 2>>{}); + static_assert(container_reduce(dstr_m.get_lengths(), + std::multiplies{}, + 1) == kNPerBlock * kKPerBlock); + return dstr_m; + } } } @@ -897,14 +959,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, tuple, sequence>, @@ -913,7 +975,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, // N0 K2 <-> N2 sequence<0, 2, 2>>{}); } - else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0) + else if constexpr(get_warp_size() % (K2 * N0) == 0) { constexpr index_t K1 = get_warp_size() / (K2 * N0); constexpr index_t K0 = kBlockSize / get_warp_size(); diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index 41a744ea91..ca82519e72 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -7,20 +7,22 @@ namespace ck_tile { -static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len) +template +static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length() { - if(len == 96) + if constexpr(Headdim == 48) + return 48; + else if constexpr(Headdim == 96) return 128; - if(len == 160) + else if constexpr(Headdim == 160) return 256; - if(len == 192) + else if constexpr(Headdim == 192) return 192; - - // only length of 96, 160 and power-of-two is supported - if(!(len & (len - 1))) - return len; - - return 0; + else if constexpr(is_power_of_two_integer(Headdim)) + return Headdim; + else + static_assert(Headdim == 0, + "only Headdim of 48, 96, 160, 192 and power-of-two is supported"); }; template (); // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index f83bbc2a18..21f21e1aa0 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -12,6 +12,24 @@ namespace ck_tile { +// fp32 + +using WarpGemmMfmaF32F32F32M16N16K4 = WarpGemmImpl< + WarpGemmAttributeMfma>>; + +template +using WarpGemmMfmaF32F32F32M16N16K16 = WarpGemmImpl, + 4, + AttrNumAccess>>; + +template +using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution = + WarpGemmImpl, + 4, + AttrNumAccess>>; + // fp16 using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl< diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 11a8416fb2..7528760439 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -61,6 +61,135 @@ enum class WGAttrCtlEnum DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \ } +// F32 +template +struct WarpGemmAttributeMfmaImplF32F32F32M16N16K4 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + + using ADataType = float; + using BDataType = float; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 4; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 1; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x4f32", Ctrl) + else + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx9__) + return bit_cast( + __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +struct WarpGemmAttributeMfmaImplF32F32F32M32N32K2 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + + using ADataType = float; + using BDataType = float; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 2; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 1; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x2f32", Ctrl) + else + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx9__) + return bit_cast( + __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + // V_MFMA_F32_16x16x32_BF16 template struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32 diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 5eedd42b04..924f7c4a54 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -23,6 +23,11 @@ template struct WarpGemmDispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K4; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; }; // fp16 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index df3a03cca8..292bc41a0b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -38,8 +38,10 @@ set(REGRESSION_TESTS test_conv_tensor_rearrange test_gemm_mx test_ck_tile_batched_transpose + test_ck_tile_fmha_bwd_fp32 test_ck_tile_fmha_bwd_bf16 test_ck_tile_fmha_bwd_fp16 + test_ck_tile_fmha_fwd_fp32 test_ck_tile_fmha_fwd_bf16 test_ck_tile_fmha_fwd_fp16 test_ck_tile_fmha_fwd_fp8 diff --git a/test/ck_tile/fmha/CMakeLists.txt b/test/ck_tile/fmha/CMakeLists.txt index b17d682560..8e5cce4c0b 100644 --- a/test/ck_tile/fmha/CMakeLists.txt +++ b/test/ck_tile/fmha/CMakeLists.txt @@ -6,12 +6,18 @@ endif() set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") +add_gtest_executable(test_ck_tile_fmha_bwd_fp32 test_fmha_bwd_fp32.cpp) +target_link_libraries(test_ck_tile_fmha_bwd_fp32 PRIVATE ${FMHA_BWD_INSTANCES}) + add_gtest_executable(test_ck_tile_fmha_bwd_bf16 test_fmha_bwd_bf16.cpp) target_link_libraries(test_ck_tile_fmha_bwd_bf16 PRIVATE ${FMHA_BWD_INSTANCES}) add_gtest_executable(test_ck_tile_fmha_bwd_fp16 test_fmha_bwd_fp16.cpp) target_link_libraries(test_ck_tile_fmha_bwd_fp16 PRIVATE ${FMHA_BWD_INSTANCES}) +add_gtest_executable(test_ck_tile_fmha_fwd_fp32 test_fmha_fwd_fp32.cpp) +target_link_libraries(test_ck_tile_fmha_fwd_fp32 PRIVATE ${FMHA_FWD_INSTANCES}) + add_gtest_executable(test_ck_tile_fmha_fwd_bf16 test_fmha_fwd_bf16.cpp) target_link_libraries(test_ck_tile_fmha_fwd_bf16 PRIVATE ${FMHA_FWD_INSTANCES}) @@ -23,8 +29,10 @@ target_link_libraries(test_ck_tile_fmha_fwd_fp8 PRIVATE ${FMHA_FWD_INSTANCES}) add_custom_target(test_ck_tile_fmha DEPENDS + test_ck_tile_fmha_bwd_fp32 test_ck_tile_fmha_bwd_bf16 test_ck_tile_fmha_bwd_fp16 + test_ck_tile_fmha_fwd_fp32 test_ck_tile_fmha_fwd_bf16 test_ck_tile_fmha_fwd_fp16 test_ck_tile_fmha_fwd_fp8 diff --git a/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp b/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp new file mode 100644 index 0000000000..d409d0dd30 --- /dev/null +++ b/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "example/ck_tile/01_fmha/fmha_bwd.hpp" +#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp" + +#include "gtest/gtest.h" + +using DataTypeConfig = FmhaBwdFp32; + +using ::testing::Values; +using ::testing::ValuesIn; + +const auto HDimValues = Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}); + +const auto ModeValues = Values(mode_enum::batch, mode_enum::group); + +constexpr std::string init_method = "uf"; + +#include "test_fmha_bwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_fwd.inc b/test/ck_tile/fmha/test_fmha_fwd.inc index 9497122594..ccca5cf969 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.inc +++ b/test/ck_tile/fmha/test_fmha_fwd.inc @@ -515,6 +515,8 @@ class PagedKV : public TestWithParam, { }; +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(PagedKV); + INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, PagedKV, Combine(SplitKVHDimValues, @@ -580,6 +582,8 @@ class SplitKV : public TestWithParam, { }; +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(SplitKV); + INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, SplitKV, Combine(SplitKVHDimValues, @@ -662,6 +666,8 @@ INSTANTIATE_TEST_SUITE_P( std::tuple{2, 3, 1, 264, 265, "1"}, std::tuple{4, 4, 2, 71, 64, "1"}))); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKV); + TEST_P(AppendKV, Test) { auto [hdims, diff --git a/test/ck_tile/fmha/test_fmha_fwd_fp32.cpp b/test/ck_tile/fmha/test_fmha_fwd_fp32.cpp new file mode 100644 index 0000000000..00f1eb0629 --- /dev/null +++ b/test/ck_tile/fmha/test_fmha_fwd_fp32.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "example/ck_tile/01_fmha/fmha_fwd.hpp" +#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp" + +#include "gtest/gtest.h" + +#include +#include + +using ::testing::Values; + +using DataTypeConfig = FmhaFwdFp32; + +const auto HDimValues = Values(std::tuple{32, -1}, + std::tuple{48, -1}, + std::tuple{64, -1}, + std::tuple{96, 128}, + std::tuple{128, -1}, + std::tuple{192, -1}, + std::tuple{256, -1}); + +const auto SplitKVHDimValues = Values(); + +const auto AppendKVHDimValues = Values(); + +const auto ModeValues = Values(mode_enum::batch, mode_enum::group); + +const auto IsVRowmajorValues = Values(true); + +const bool squant = false; +const std::string init_method = "uf"; +const bool def_lse = true; +const bool def_is_v_rowmajor = true; + +int adjust_seqlen(int seqlen) { return seqlen; } + +#include "test_fmha_fwd.inc"