diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fce19483c..a3c77bbc50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,12 +19,15 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for f32 to FMHA (fwd/bwd). * Added tensor-wise quantization for CK_TILE GEMM. * Added support for batched contraction kernel. +* Added WMMA (gfx12) support for FMHA. * Added pooling kernel in CK_TILE * Added top-k sigmoid kernel in CK_TILE ### Changed * Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594) +* Added an optional template parameter `Arch` (`gfx9_t`, `gfx12_t` etc.) to `make_kernel` to support linking multiple object files that have the same kernel compiled for different architectures. +* FMHA examples and tests can be built for multiple architectures (gfx9, gfx950, gfx12) at the same time. ## Composable Kernel 1.1.0 for ROCm 7.1.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index f58dff8e15..049da5637f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -346,7 +346,6 @@ endif() option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) -option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF) option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF) option(ENABLE_JSON_DUMP "Whether to enable json dump for examples." OFF) @@ -617,11 +616,11 @@ endif() if(NOT MIOPEN_REQ_LIBS_ONLY) # make check runs the entire set of examples and tests - add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) + add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} USES_TERMINAL) # make smoke runs the tests and examples that runs within 30 seconds on gfx90a - add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST") + add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST" USES_TERMINAL) # make regression runs the tests and examples that runs for more 30 seconds on gfx90a - add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST") + add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST" USES_TERMINAL) endif() diff --git a/Jenkinsfile b/Jenkinsfile index 98ca17a571..af3e211aef 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -42,6 +42,34 @@ def checkForPattern(pattern, log) { return [found: false, matchedLine: "", context: ""] } +// Scan build logs and send notifications +def sendFailureNotifications() { + // Get the build log. + def buildLog = sh(script: 'wget -q --no-check-certificate -O - ' + BUILD_URL + 'consoleText', returnStdout: true) + // Check for patterns in the log. + def foundPatterns = [] + for (patternMap in failurePatterns) { + def result = checkForPattern(patternMap.pattern, buildLog) + if (result.found) { + foundPatterns.add([ + description: patternMap.description, + matchedLine: result.matchedLine, + context: result.context + ]) + } + } + // Send a notification for each matched failure pattern. 
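The `sendFailureNotifications()` helper above hoists the log-scanning logic that previously lived inline in the pipeline's `failure` block. A minimal Python sketch of the same scan-and-collect flow; the pattern table and log text here are illustrative stand-ins, not the Jenkinsfile's actual `failurePatterns` data:

```python
import re

# Illustrative stand-ins; the real pipeline defines failurePatterns in Groovy.
failure_patterns = [
    {"pattern": r"error: ", "description": "Compiler error"},
    {"pattern": r"No space left on device", "description": "Disk full"},
]

def scan_log(build_log: str) -> list:
    lines = build_log.splitlines()
    found = []
    for entry in failure_patterns:
        for i, line in enumerate(lines):
            if re.search(entry["pattern"], line):
                # Keep a few surrounding lines as context, roughly as checkForPattern does.
                context = "\n".join(lines[max(0, i - 2) : i + 3])
                found.append({"description": entry["description"],
                              "matchedLine": line,
                              "context": context})
                break  # first match per pattern is enough for one notification
    return found

matches = scan_log("cc1plus: error: out of memory\nNo space left on device")
print(len(matches))  # 2 -- one notification would be sent per matched pattern
```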
+ for (patternMap in foundPatterns) { + withCredentials([string(credentialsId: 'ck_ci_errors_webhook_url', variable: 'WEBHOOK_URL')]) { + sh ''' + curl -X POST "${WEBHOOK_URL}" \ + -H 'Content-Type: application/json' \ + -d '{"text": "\\n\\n**Build Failed**\\n\\n**Issues detected:** ''' + patternMap.description + '''\\n\\n**Log context:**\\n```\\n''' + patternMap.context.replace("'", "\\'") + '''\\n```\\n\\n**Job:** ''' + env.JOB_NAME + '''\\n\\n**Build:** #''' + env.BUILD_NUMBER + '''\\n\\n**URL:** ''' + env.RUN_DISPLAY_URL + '''"}' + ''' + } + } +} + class Version { int major, minor, patch @Override @@ -1557,6 +1585,25 @@ pipeline { cleanWs() } } + stage("Run CK_TILE_FMHA Tests on gfx1201") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_FMHA_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx1201") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx12-generic && \ + make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ + cd ../ && + example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx1201 """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } } } stage("Run TILE_ENGINE_GEMM Tests") @@ -1863,7 +1910,7 @@ pipeline { } agent{ label 'miopen && (gfx1101 || gfx1100)' } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx11-generic" \ @@ -1884,7 +1931,7 @@ pipeline { } agent{ label rocmnode("gfx1201") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DUSE_OPT_GFX12=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx12-generic" \ @@ -1948,30 +1995,7 @@ pipeline { failure { node(rocmnode("nogpu")) { script { - // Get the build log. - def buildLog = sh(script: 'wget -q --no-check-certificate -O - ' + BUILD_URL + 'consoleText', returnStdout: true) - // Check for patterns in the log. - def foundPatterns = [] - for (patternMap in failurePatterns) { - def result = checkForPattern(patternMap.pattern, buildLog) - if (result.found) { - foundPatterns.add([ - description: patternMap.description, - matchedLine: result.matchedLine, - context: result.context - ]) - } - } - // Send a notification for each matched failure pattern. 
-                    for (patternMap in foundPatterns) {
-                        withCredentials([string(credentialsId: 'ck_ci_errors_webhook_url', variable: 'WEBHOOK_URL')]) {
-                            sh '''
-                            curl -X POST "${WEBHOOK_URL}" \
-                            -H 'Content-Type: application/json' \
-                            -d '{"text": "\\n\\n**Build Failed**\\n\\n**Issues detected:** ''' + patternMap.description + '''\\n\\n**Log context:**\\n```\\n''' + patternMap.context.replace("'", "\\'") + '''\\n```\\n\\n**Job:** ''' + env.JOB_NAME + '''\\n\\n**Build:** #''' + env.BUILD_NUMBER + '''\\n\\n**URL:** ''' + env.RUN_DISPLAY_URL + '''"}'
-                            '''
-                        }
-                    }
+                    sendFailureNotifications()
                 }
             }
         }
diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt
index b8ca26193d..ce914b92af 100644
--- a/example/ck_tile/01_fmha/CMakeLists.txt
+++ b/example/ck_tile/01_fmha/CMakeLists.txt
@@ -1,8 +1,8 @@
 set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
-# Currently only gfx9 archs are supported by FMHA
-list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
+# Currently only gfx9 and gfx12 archs are supported by FMHA
+list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12")
 if(NOT INST_TARGETS)
-    message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
+    message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
     return()
 endif()
@@ -12,6 +12,7 @@ set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
     "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
 if(BUILD_TESTING)
     # Build instances of all APIs for tests
+    message(DEBUG "Enabling all FWD APIs of CK Tile FMHA because testing is enabled")
     set(FMHA_FWD_ENABLE_APIS "all")
 endif()
 if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
@@ -36,15 +37,19 @@ file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
 # re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change
 set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
+list(JOIN INST_TARGETS , FMHA_TARGETS_ARG)
+
 string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
 set(FMHA_FWD_CODE_GEN_COMMON_ARGS
     ${CMAKE_CURRENT_LIST_DIR}/generate.py
+    --targets ${FMHA_TARGETS_ARG}
     --api ${FMHA_FWD_APIS}
     --optdim 32,64,128,256
     # --filter fmha_fwd...
) set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py + --targets ${FMHA_TARGETS_ARG} --api bwd --receipt 3 --optdim 32,64,96,128,256 @@ -67,7 +72,7 @@ execute_process( RESULT_VARIABLE ret ) if(ret AND NOT ret EQUAL 0) - message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.") + message(FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD kernels via Python.") endif() execute_process( @@ -76,7 +81,7 @@ execute_process( RESULT_VARIABLE ret ) if(ret AND NOT ret EQUAL 0) - message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.") + message(FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.") endif() # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory @@ -89,6 +94,7 @@ add_custom_command( COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${CODE_GEN_SCRIPTS} + COMMENT "Generate CK Tile FMHA FWD kernels" ) add_custom_command( @@ -96,6 +102,7 @@ add_custom_command( COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${CODE_GEN_SCRIPTS} + COMMENT "Generate CK Tile FMHA BWD kernels" ) set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") diff --git a/example/ck_tile/01_fmha/codegen/arch.py b/example/ck_tile/01_fmha/codegen/arch.py new file mode 100644 index 0000000000..1bfc78d3cd --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/arch.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +from dataclasses import dataclass, field +from typing import Any, List, Callable + + +@dataclass(frozen=True) +class ArchTrait: + name: str + preprocessor_check: str = field(default=None) + device_name_check: str = field(default=None) + tag: str = field(default=None) + filename_suffix: str = field(default=None) + + def __post_init__(self): + if self.preprocessor_check is None: + object.__setattr__(self, "preprocessor_check", f"defined(__{self.name}__)") + if self.device_name_check is None: + object.__setattr__( + self, + "device_name_check", + f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0', + ) + if self.tag is None: + object.__setattr__(self, "tag", f"ck_tile::{self.name}_t") + if self.filename_suffix is None: + object.__setattr__(self, "filename_suffix", f"_{self.name}") + + +def get_factories_for_targets( + targets: List[str], get_factory: Callable[[str], Any] +) -> List[Any]: + factories = dict() + for target in targets: + factory = get_factory(target) + factories[factory.arch.name] = factory + # Place more specific architectures first + factories = sorted( + list(factories.values()), key=lambda f: len(f.arch.name), reverse=True + ) + return factories diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 2e3f96e4a6..74db4e084c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
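A usage sketch for the new `codegen/arch.py` helpers above; the two factory classes are stand-ins for the real per-arch factories defined later in this patch:

```python
from codegen.arch import ArchTrait, get_factories_for_targets

class FactoryGfx9:    # stand-in
    arch = ArchTrait("gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)")

class FactoryGfx950:  # stand-in
    arch = ArchTrait("gfx950")

def get_factory(target: str):
    if target.startswith("gfx950"):
        return FactoryGfx950
    if target.startswith("gfx9"):
        return FactoryGfx9
    raise Exception(f"Unsupported device target {target}")

# Duplicate targets collapse to one factory per architecture, and longer
# (more specific) arch names sort first.
factories = get_factories_for_targets(["gfx90a", "gfx950", "gfx942"], get_factory)
print([f.arch.name for f in factories])      # ['gfx950', 'gfx9']
print(FactoryGfx950.arch.tag)                # ck_tile::gfx950_t
print(FactoryGfx950.arch.device_name_check)  # device_name.compare(0, 6, "gfx950") == 0
```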
# generate kernel instances to speed up compilation import copy @@ -21,6 +21,7 @@ from codegen.cpp_symbol_map import ( BOOL_MAP, PIPELINE_ENUM_MAP, ) +from codegen.utils import update_file DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} @@ -441,7 +442,7 @@ class FmhaFwdApiPool: ) if not per_dtypes: # empty string we add some ignore to suppress warning in api - per_dtypes += " (void)t ; (void)s ; (void)a;" + per_dtypes += " (void)t; (void)s; (void)a;" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes) @@ -720,15 +721,20 @@ def get_fwd_blobs( def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) + update_file(autogen_dir / kernel.filename, kernel.template) def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) def write_blobs( - output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl + targets: List[str], + output_dir: Path, + kernel_filter: str, + receipt, + optdim_list, + mask_impl, ) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: @@ -737,7 +743,12 @@ def write_blobs( def list_blobs( - file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl + targets: List[str], + file_path: Path, + kernel_filter: str, + receipt, + optdim_list, + mask_impl, ) -> None: with file_path.open("a") as f: _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index d007b4caa3..7238749bfc 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -3,13 +3,14 @@ # generate kernel instances to speed up compilation import copy -from dataclasses import dataclass import fnmatch import itertools +from collections import OrderedDict +from dataclasses import dataclass from pathlib import Path from typing import List, Tuple, Dict, Literal, Any -from collections import defaultdict +from codegen.arch import ArchTrait, get_factories_for_targets from codegen.cmake_config import GEN_DIR from codegen.cpp_symbol_map import ( get_mask_check_map, @@ -22,16 +23,20 @@ from codegen.cpp_symbol_map import ( BWD_DTYPE_MAP, BOOL_MAP, ) -from codegen.utils import update_file +from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved.\n // auto generated by generate.py #include "fmha_bwd.hpp" """ FMHA_BWD_DQ_DK_DV_KERNEL_BODY = """ +#include + +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) + using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile:: @@ -132,10 +137,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_maxq}, {F_bn0}>; -#include - template <> -float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; if(s.log_level_ > 0) @@ -144,67 +147,68 @@ float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template <> -void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, - fmha_bwd_args a) +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( ck_tile::stream_config{{s.stream_id_}}); }} template <> -int fmha_bwd_dq_dk_dv_maxq_() +int fmha_bwd_dq_dk_dv_maxq_() {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; return k_::kMaxSeqLenQ; }} template <> -std::string fmha_bwd_dq_dk_dv_get_name_() +std::string fmha_bwd_dq_dk_dv_get_name_() {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; return k_::GetName(); }} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) """ FMHA_BWD_API_FILENAME = "fmha_bwd_api.cpp" FMHA_BWD_API = """ #include -template +template float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ if constexpr (!std::is_same_v) {{ if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_convert_dq_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_convert_dq_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} ); }} else {{ if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, 
a); }} ); }} }} template <> float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ - [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); + [[maybe_unused]] const std::string device_name = ck_tile::get_device_name(); float r = -1; {F_dispatch} return r; @@ -212,23 +216,22 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf """ -def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_=0) -> str: +def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, if_i=0) -> str: lines = [ - f"{'if' if if_ == 0 else 'else if'}({F_cond})", + f"{if_(if_i)}({F_cond})", "{", - *[" " + line for line in F_body.split("\n") if line.strip() != ""], + indent(F_body), "}", ] - return "\n".join(" " * indent + line for line in lines) + "\n" + return "\n".join(lines) + "\n" -FMHA_BWD_API_INNER_DISPATCH = """ -{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && - ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{ +FMHA_BWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && + ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_max_seq_q_cond}{F_cond_extra}) {{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>; using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>; - r = fmha_bwd_>(s, a); + r = fmha_bwd_, {F_arch.tag}>(s, a); return r; }} """ @@ -283,6 +286,7 @@ class FmhaBwdDQDKDVTileSize: @dataclass(frozen=True) class FmhaBwdDQDKDVKernel: + F_arch: ArchTrait F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type @@ -302,6 +306,7 @@ class FmhaBwdDQDKDVKernel: def template(self) -> str: return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( F_idx=self.F_idx, + F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=BWD_DTYPE_MAP[self.F_dtype], F_bm0=self.F_tile.F_bm0, @@ -399,43 +404,97 @@ class FmhaBwdDQDKDVKernel: @property def filename(self) -> str: - return self.name + ".cpp" + return f"{self.name}{self.F_arch.filename_suffix}.cpp" -# TODO: design a more practical way to do it -# this is current supported tile size. 
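The two per-arch mechanisms introduced above work together: each generated translation unit is wrapped in a preprocessor guard so its device code is compiled only for the matching target, and the filename suffix keeps per-arch copies of the same kernel from colliding at link time. A small sketch of what both expand to, using the traits from `codegen/arch.py` (the kernel name is illustrative):

```python
from codegen.arch import ArchTrait

archs = [
    ArchTrait("gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)"),
    ArchTrait("gfx950"),
    ArchTrait("gfx12"),
]
for arch in archs:
    # Host passes (__HIP_DEVICE_COMPILE__ unset) see every body; each device
    # pass compiles only the body for its own architecture.
    print(f"#if !defined(__HIP_DEVICE_COMPILE__) || ({arch.preprocessor_check})")
    # Same kernel, distinct per-arch translation units.
    print(f"fmha_bwd_d128_bf16_batch{arch.filename_suffix}.cpp")
```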
-def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: - if dtype == "fp32" and tr_load == "f": - return [ - # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, - FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), - ] # fmt: skip - elif (dtype == "fp16" or dtype == "bf16") and tr_load == "f": - return [ - FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - # FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - ] # fmt: skip - elif (dtype == "fp16" or dtype == "bf16") and tr_load == "t": - return [ - FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), - FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), - FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - # FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32), - FmhaBwdDQDKDVTileSize( 32, 16, 64, 32, 64, 32, 16, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32), - # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16), - FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16), - ] # fmt: skip - else: +class KernelComponentFactoryBase: + pass + + +class KernelComponentFactoryGfx9(KernelComponentFactoryBase): + arch = ArchTrait( + "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)" + ) + + @staticmethod + def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: + if tr_load == "t": + return [] + if dtype in ["fp32"]: + return [ + # bm0, bn0, bk0, bk1, bk2, bk3, bk4,bhdq,bhdv, + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + ] # fmt: skip + if dtype in ["fp16", "bf16"]: + return [ + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 
128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + # FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + ] # fmt: skip return [] +class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): + arch = ArchTrait("gfx950") + + @staticmethod + def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: + results = KernelComponentFactoryGfx9.get_dq_dk_dv_tiles(dtype, tr_load) + if dtype in ["fp16", "bf16"] and tr_load == "t": + results.extend([ + FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), + FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + # FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32), + FmhaBwdDQDKDVTileSize( 32, 16, 64, 32, 64, 32, 16, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32), + # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16), + FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16), + ]) # fmt: skip + return results + + +class KernelComponentFactoryGfx12(KernelComponentFactoryBase): + arch = ArchTrait("gfx12") + + @staticmethod + def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: + if tr_load == "t": + return [] + if dtype in ["fp16", "bf16"]: + return [ + # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, + FmhaBwdDQDKDVTileSize( 32, 64, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaBwdDQDKDVTileSize( 32, 64, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + ] # fmt: skip + return [] + + +def get_factory(target: str): + # Place more specific architectures first + + if target.startswith("gfx950"): + return KernelComponentFactoryGfx950 + if target.startswith("gfx9"): + return KernelComponentFactoryGfx9 + + if target.startswith("gfx12"): + return KernelComponentFactoryGfx12 + + raise Exception(f"Unsupported device target {target}") + + FMHA_BWD_DOT_DO_O_KERNEL_BODY = """ +#include + +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) + using fmha_dtype_{F_idx} = {F_dtype}; using fmha_bwd_dot_do_o_trait_{F_idx} = @@ -445,7 +504,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot typename FmhaBwdTypeConfig::ODataType, typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::DDataType, - /* BlockSize = M0 = */ 64, + /* BlockSize = M0 = */ {F_bm0}, {F_hdim}, {F_mode}, fmha_bwd_dot_do_o_trait_{F_idx}>; @@ -459,10 +518,8 @@ using fmha_bwd_dot_do_o_kernel_{F_idx} = using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; -#include - template <> -float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) 
+float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; if(s.log_level_ > 0) @@ -471,34 +528,38 @@ float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template <> -void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( ck_tile::stream_config{{s.stream_id_}}); }} template <> -std::string fmha_bwd_dot_do_o_get_name_() +std::string fmha_bwd_dot_do_o_get_name_() {{ using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; return k_::GetName(); }} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) """ @dataclass(frozen=True) class FmhaBwdOGradDotOKernel: + F_arch: ArchTrait F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type + F_bm0: int # tile size along q seqlen (block size) F_spad: str # true/false F_dvpad: str # F_mode: str # value from MODE_MAP @@ -508,8 +569,10 @@ class FmhaBwdOGradDotOKernel: def template(self) -> str: return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( F_idx=self.F_idx, + F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=BWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_bm0, F_spad=BOOL_MAP[self.F_spad], F_dvpad=BOOL_MAP[self.F_dvpad], F_mode=MODE_MAP[self.F_mode], @@ -529,7 +592,7 @@ class FmhaBwdOGradDotOKernel: return n pn = pad_name() - n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" + n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}_{self.F_mode}_o{self.F_occupancy}" if pn != "": n += f"_{pn}" else: @@ -538,10 +601,14 @@ class FmhaBwdOGradDotOKernel: @property def filename(self) -> str: - return self.name + ".cpp" + return f"{self.name}{self.F_arch.filename_suffix}.cpp" FMHA_BWD_CONVERT_DQ_KERNEL_BODY = """ +#include + +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) + using fmha_dtype_{F_idx} = {F_dtype}; using fmha_bwd_convert_dq_trait_{F_idx} = @@ -573,10 +640,8 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_deterministic}, {F_bn0}>; -#include - template <> -float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; if(s.log_level_ > 0) @@ -585,32 +650,34 @@ float fmha_bwd_convert_dq_(const ck_tile::stream_confi const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template <> -void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, - fmha_bwd_args a) +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = 
fmha_bwd_convert_dq_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( ck_tile::stream_config{{s.stream_id_}}); }} template <> -std::string fmha_bwd_convert_dq_get_name_() +std::string fmha_bwd_convert_dq_get_name_() {{ using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; return k_::GetName(); }} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) """ @dataclass(frozen=True) class FmhaBwdConvertQGradKernel: + F_arch: ArchTrait F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type @@ -627,6 +694,7 @@ class FmhaBwdConvertQGradKernel: def template(self) -> str: return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format( F_idx=self.F_idx, + F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=BWD_DTYPE_MAP[self.F_dtype], F_bm0=self.F_bm0, @@ -664,11 +732,12 @@ class FmhaBwdConvertQGradKernel: @property def filename(self) -> str: - return self.name + ".cpp" + return f"{self.name}{self.F_arch.filename_suffix}.cpp" @dataclass(frozen=True) class FmhaBwdApiTrait: + arch: ArchTrait idx: int # this is not a tunable, but a counter to differentiate symbol # sync with fmha_bwd_traits<>, to generate fallback calls hdim: int @@ -705,10 +774,10 @@ class FmhaBwdApiTrait: @property def scheck(self) -> str: if self.mode == "group": - return "true" # always support + return "true /*spad1d is always true in group mode*/" elif self.spad1d == "t": - return f"a.seqlen_q % {M0_1D} != 0" - else: # self.spad1d == 'f' + return f"true /*a.seqlen_q % {M0_1D} != 0*/" + else: # self.spad1d == "f" return f"a.seqlen_q % {M0_1D} == 0" @property @@ -725,10 +794,17 @@ class FmhaBwdApiTrait: else: return f"a.hdim_v % {self.dvpad} == 0" + @property + def max_seq_q_cond(self) -> str: + if self.tile.max_seq_q != 0: + return f" && (a.seqlen_q <= {self.tile.max_seq_q})" + else: + return "" + @property def extra_cond(self) -> str: if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128: - return "&& (a.seqlen_k <= 256)" + return " && (a.seqlen_k <= 256)" else: return "" @@ -745,9 +821,11 @@ class FmhaBwdApiTrait: F_dvpad = "t" if self.dvpad else "f" return FmhaBwdOGradDotOKernel( + F_arch=self.arch, F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, + F_bm0=M0_1D, F_spad=self.spad1d, F_dvpad=F_dvpad, F_mode=self.mode, @@ -757,6 +835,7 @@ class FmhaBwdApiTrait: @property def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: return FmhaBwdDQDKDVKernel( + F_arch=self.arch, F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, @@ -782,6 +861,7 @@ class FmhaBwdApiTrait: F_dpad = "t" if self.dpad else "f" return FmhaBwdConvertQGradKernel( + F_arch=self.arch, F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, @@ -798,28 +878,25 @@ class FmhaBwdApiTrait: class FmhaBwdApiPool: def __init__(self, mask_impl): - self.dq_dk_dv_pool = defaultdict( - lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))) - ) - + self.dq_dk_dv_pool = OrderedDict() self.mask_impl = mask_impl def register_dq_dk_dv_traits(self, trait: FmhaBwdApiTrait) -> None: - # TODO: do we need to check duplication? 
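The registration rewrite shown just below replaces the old four-deep `defaultdict` nesting with insertion-ordered maps keyed arch → dtype → hdim. A minimal sketch of the pattern, with string stand-ins for the trait objects:

```python
from collections import OrderedDict

pool = OrderedDict()

def register(arch: str, dtype: str, hdim: int, trait: str) -> None:
    # arch -> dtype -> hdim -> [traits], preserving registration order.
    (pool.setdefault(arch, OrderedDict())
         .setdefault(dtype, OrderedDict())
         .setdefault(hdim, [])
         .append(trait))

register("gfx950", "bf16", 128, "trait_a")
register("gfx950", "bf16", 128, "trait_b")
register("gfx9",   "fp16",  64, "trait_c")
print(pool["gfx950"]["bf16"][128])  # ['trait_a', 'trait_b']
print(list(pool))                   # ['gfx950', 'gfx9'] -- insertion order kept
```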
- self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][ - trait.hdim - ].append(copy.copy(trait)) + hdim = trait.hdim + ts = ( + self.dq_dk_dv_pool.setdefault(trait.arch, OrderedDict()) + .setdefault(trait.dtype, OrderedDict()) + .setdefault(hdim, []) + ) + check_duplicates_and_paddings(ts, trait) + ts.append(copy.copy(trait)) - @staticmethod - def if_(i: int) -> str: - return "if" if i == 0 else "else if" - - def _api_innders(self, traits: List[FmhaBwdApiTrait]) -> str: + def _api_inners(self, traits: List[FmhaBwdApiTrait]) -> str: inners = "" - i = 0 - for trait in traits: + for i_trait, trait in enumerate(traits): inners += FMHA_BWD_API_INNER_DISPATCH.format( - F_if=self.if_(i), + F_if=if_(i_trait), + F_arch=trait.arch, F_mode=MODE_MAP[trait.mode], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], @@ -840,27 +917,18 @@ class FmhaBwdApiPool: F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], - F_bn0=trait.tile.F_bn0, + F_max_seq_q_cond=trait.max_seq_q_cond, F_cond_extra=trait.extra_cond, + F_bn0=trait.tile.F_bn0, F_convert_dq_bn0=trait.convert_dq_bn0, ) - i += 1 return inners @staticmethod - def trload_sort_key(tf): - return 0 if tf == "t" else 1 # sort 't' before 'f' - - @staticmethod - def max_seq_q_sort_key(max_seq_q): - return max_seq_q if max_seq_q != 0 else 1000000 # sort 0 to the end - - @staticmethod - def max_seq_q_cond(max_seq_q: int) -> str: - if max_seq_q == 0: - return "true /* no seqlen_q limit */" - else: - return f"a.seqlen_q <= {max_seq_q}" + def max_seq_q_sort_key(trait): + return ( + trait.tile.max_seq_q if trait.tile.max_seq_q != 0 else 1000000 + ) # sort 0 to the end @staticmethod def dtype_cond(dtype: str) -> str: @@ -872,42 +940,34 @@ class FmhaBwdApiPool: @property def api(self) -> str: - tr_load_cond_map = {"t": "has_load_tr", "f": "true /* no trload requirement */"} - per_tr_load = "" - for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key): - per_max_seq_q = "" - for max_seq_q in sorted( - self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key - ): - per_dtypes = "" - for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]): - per_hdim_case = "" - for k, hdim in enumerate( - self.dq_dk_dv_pool[tr_load][max_seq_q][dtype] - ): - traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim] - inners = self._api_innders(traits) - per_hdim_case += FMHA_BWD_API_COND_STATEMENT( - if_=k, F_cond=self.hdim_cond(hdim), F_body=inners - ) - per_dtypes += FMHA_BWD_API_COND_STATEMENT( - if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case + per_arch = "" + for i_arch, (arch, pool_by_arch) in enumerate(self.dq_dk_dv_pool.items()): + per_dtypes = "" + for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): + per_hdim_case = "" + for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()): + traits = sorted(pool_by_hdim, key=self.max_seq_q_sort_key) + inners = self._api_inners(traits) + per_hdim_case += FMHA_BWD_API_COND_STATEMENT( + if_i=i_hdim, F_cond=self.hdim_cond(hdim), F_body=inners ) - per_max_seq_q += FMHA_BWD_API_COND_STATEMENT( - F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes + per_dtypes += FMHA_BWD_API_COND_STATEMENT( + if_i=i_dtype, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case ) - per_tr_load += FMHA_BWD_API_COND_STATEMENT( - F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4 + per_arch += 
FMHA_BWD_API_COND_STATEMENT( + if_i=i_arch, F_cond=arch.device_name_check, F_body=per_dtypes ) - if not per_tr_load: + if not per_arch: # empty string we add some ignore to suppress warning in api - per_tr_load += " (void)t ; (void)s ; (void)a; (void)has_load_tr;" - result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch=per_tr_load) + per_arch = "(void)t; (void)s; (void)a;" + result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format( + F_dispatch=indent(per_arch) + ) return result.replace("\n\n", "\n") def get_bwd_blobs( - filter_list: str, receipt, mask_impl, optdim_list + targets: List[str], filter_list: str, receipt, mask_impl, optdim_list ) -> Tuple[ FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], @@ -922,14 +982,19 @@ def get_bwd_blobs( filter_convert_dq = filters[1] filter_dq_dk_dv = filters[2] + factories = get_factories_for_targets(targets, get_factory) + # use dict as ordered set - gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {} - gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {} - gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {} + gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = OrderedDict() + gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = OrderedDict() + gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = OrderedDict() api_pool = FmhaBwdApiPool(mask_impl) - for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): - tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load) + for factory, dtype, tr_load in itertools.product( + factories, BWD_DTYPE_MAP.keys(), ["t", "f"] + ): + tiles: Any = factory.get_dq_dk_dv_tiles(dtype, tr_load) + spad1d_options = ["f", "t"] dpad_options = itertools.product(*([[0, 8, 1]] * 2)) tf = ["t", "f"] for tile, mode, mask, bias, dbias, dropout, spad1d, ( @@ -942,7 +1007,7 @@ def get_bwd_blobs( BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), - tf, + spad1d_options, dpad_options, tf, ): @@ -958,6 +1023,8 @@ def get_bwd_blobs( continue if "wg32" in dropout: continue + if spad1d == "f" and tile.max_seq_q != 0 and tile.max_seq_q < M0_1D: + continue # max_seq_q < M0_1D requires padding if tr_load == "t": # tr_load can only work with 8 pad if dpad != dvpad or dpad == 1: @@ -970,6 +1037,7 @@ def get_bwd_blobs( if hdim not in optdim_list: continue t = FmhaBwdApiTrait( + arch=factory.arch, idx=0, hdim=hdim, dtype=dtype, @@ -989,10 +1057,10 @@ def get_bwd_blobs( if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): continue - if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv): - continue if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq): continue + if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv): + continue # Flash attention integration if receipt == 2: @@ -1076,10 +1144,15 @@ def get_bwd_blobs( def write_blobs( - output_dir: Path, filter_list: str, receipt, optdim_list, mask_impl + targets: List[str], + output_dir: Path, + filter_list: str, + receipt, + optdim_list, + mask_impl, ) -> None: api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs( - filter_list, receipt, mask_impl, optdim_list + targets, filter_list, receipt, mask_impl, optdim_list ) update_file(output_dir / FMHA_BWD_API_FILENAME, api_pool.api) for k in kernels_dot_do_o: @@ -1091,10 +1164,15 @@ def write_blobs( def list_blobs( - file_path: Path, filter_list: str, receipt, optdim_list, mask_impl + targets: List[str], + file_path: Path, + filter_list: str, + receipt, + optdim_list, + mask_impl, ) -> None: _, kernels_dot_do_o, 
kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs( - filter_list, receipt, mask_impl, optdim_list + targets, filter_list, receipt, mask_impl, optdim_list ) with file_path.open("a") as f: for k in kernels_dot_do_o: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 2cec9c713a..2acc467410 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1,15 +1,17 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation import copy -from dataclasses import dataclass, field import fnmatch import itertools import os +from collections import OrderedDict +from dataclasses import dataclass, field from pathlib import Path from typing import List, Optional, Tuple +from codegen.arch import ArchTrait, get_factories_for_targets from codegen.cmake_config import GEN_DIR from codegen.cpp_symbol_map import ( LAYOUT_MAP, @@ -23,7 +25,7 @@ from codegen.cpp_symbol_map import ( BIAS_MAP, get_mask_map, ) -from codegen.utils import update_file +from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} @@ -31,13 +33,17 @@ DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} K0_MAX_SUBMAX_MAP = {32: 32, 48: 48, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n // auto generated by generate.py #include "ck_tile/ops/fmha/block/variants.hpp" #include "fmha_fwd.hpp" """ FMHA_FWD_KERNEL_BODY = """ +#include + +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) + using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; @@ -99,10 +105,8 @@ using fmha_kernel_{F_idx} = using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; -#include - template<> -float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ using k_ = fmha_kernel_{F_idx}; if(s.log_level_ > 0) @@ -110,8 +114,10 @@ float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) """ FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp" @@ -148,13 +154,13 @@ unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seq }} }} // namespace -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const 
ck_tile::stream_config& s) {{ float r = -1; [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate unsigned num_cus; - if (!get_num_cus(num_cus)) {{ + if(!get_num_cus(num_cus)) {{ return r; }} @@ -162,32 +168,33 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; - [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); + [[maybe_unused]] const std::string device_name = ck_tile::get_device_name(); {F_dispatch} return r; }} """ -FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{ {F_dtype_case} - }} +}} """ -FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +FMHA_FWD_API_PER_DTYPE = """{F_if}(t.data_type.compare(\"{F_dtype}\") == 0) {{ {F_hdim_case} - }} -""" -FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ -{F_inner_dispatch} - }} +}} """ -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && - ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; - return fmha_fwd_(s, a); - }} +FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} +}} +""" + +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + return fmha_fwd_(s, a); +}} """ @@ -207,6 +214,7 @@ class CppConstraint: @dataclass class FmhaFwdApiTrait: + arch: ArchTrait pipeline_tag: str # sync with fmha_fwd_traits<>, to generate fallback calls hdim: str @@ -413,40 +421,35 @@ class FmhaFwdPipeline: class FmhaFwdApiPool: def __init__(self, mask_impl): - self.pool = dict() + self.pool = OrderedDict() self.mask_impl = mask_impl def register_traits(self, trait: FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? 
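The generated forward API now branches on the runtime device name instead of the old `has_load_tr` flag, with more specific architectures emitted first so that, for example, gfx950 is matched before the generic gfx9 check. A sketch of the dispatch skeleton this produces; `if_` here is a stand-in mirroring the helper imported from `codegen.utils`, whose body is not shown in this patch:

```python
def if_(i: int) -> str:  # stand-in mirroring codegen.utils.if_
    return "if" if i == 0 else "else if"

archs = ["gfx950", "gfx9"]  # more specific architecture first
for i, name in enumerate(archs):
    cond = f'device_name.compare(0, {len(name)}, "{name}") == 0'
    print(f"{if_(i)}({cond}) {{")
    print("    // per-dtype / per-hdim dispatch for this architecture")
    print("}")
```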
- if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() hdim = trait.hdim, trait.bn1 - if hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][hdim] = list() - - self.pool[trait.dtype][hdim].append(copy.copy(trait)) + ts = ( + self.pool.setdefault(trait.arch, OrderedDict()) + .setdefault(trait.dtype, OrderedDict()) + .setdefault(hdim, []) + ) + check_duplicates_and_paddings(ts, trait) + ts.append(copy.copy(trait)) @property def api(self) -> str: - tr_load_cond_map = {"t": "has_load_tr", "f": "true"} - - per_tr_load = str() - for tr_load in ["t", "f"]: + per_arch = str() + for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): per_dtypes = str() - for i, dtype in enumerate(self.pool.keys()): + for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): per_hdim_case = str() - for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits = [ - t - for t in self.pool[dtype][(hdim, hdim_v)] - if tr_load == t.tr_load - ] - max_bm0 = max((t.bm0 for t in traits), default=0) + for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate( + pool_by_dtype.items() + ): + max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0) inners = str() - for k, trait in enumerate(traits): - if_k = "if" if k == 0 else "else if" - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( - F_if=if_k, + for i_trait, trait in enumerate(pool_by_hdim): + inners += FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_(i_trait), + F_arch=arch, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], @@ -479,23 +482,24 @@ class FmhaFwdApiPool: F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype], ) - if_j = "if" if j == 0 else "else if" - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( - F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_(i_hdim), + F_hdim=hdim, + F_hdim_v=hdim_v, + F_inner_dispatch=indent(inners), ) - if_i = "if" if i == 0 else "else if" - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( - F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + per_dtypes += FMHA_FWD_API_PER_DTYPE.format( + F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case) ) - per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( - F_if="if", - F_trload_cond=tr_load_cond_map[tr_load], - F_dtype_case=per_dtypes, + per_arch += FMHA_FWD_API_PER_ARCH.format( + F_if=if_(i_arch), + F_arch=arch, + F_dtype_case=indent(per_dtypes), ) - if not per_tr_load: + if not per_arch: # empty string we add some ignore to suppress warning in api - per_tr_load += " (void)t ; (void)s ; (void)a;" - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + per_arch = "(void)t; (void)s; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=indent(per_arch)) @dataclass @@ -533,6 +537,7 @@ class FmhaFwdTileSize: @dataclass class FmhaFwdKernel: + F_arch: ArchTrait F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type @@ -545,6 +550,7 @@ class FmhaFwdKernel: def template(self) -> str: return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( F_idx=self.F_idx, + F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=FWD_DTYPE_MAP[self.F_dtype], F_bm0=self.F_tile.F_bm0, @@ -596,10 +602,11 @@ class FmhaFwdKernel: @property def filename(self) -> str: - return self.name + ".cpp" + return f"{self.name}{self.F_arch.filename_suffix}.cpp" def api_trait(self) -> 
FmhaFwdApiTrait: return FmhaFwdApiTrait( + arch=self.F_arch, pipeline_tag=self.F_pipeline.tag, hdim=str(self.F_hdim), dtype=self.F_dtype, @@ -627,12 +634,16 @@ class FmhaFwdKernel: ) -class KernelComponentFactory: +class KernelComponentFactoryGfx9: + arch = ArchTrait( + "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)" + ) + # TODO: design a more practical way to do it # this is current supported tile size per hdim @staticmethod def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype == "fp32": + if dtype in ["fp32"]: return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], @@ -645,10 +656,10 @@ class KernelComponentFactory: (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype == "fp16" or dtype == "bf16": + elif dtype in ["fp16", "bf16"]: return { - ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], @@ -656,18 +667,18 @@ class KernelComponentFactory: FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } # fmt: skip - elif dtype == "fp8" or dtype == "fp8bf16": + elif dtype in ["fp8", "fp8bf16"]: return { ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip - elif dtype == "fp8fp32": + elif dtype in ["fp8fp32"]: return { (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip @@ -680,7 +691,7 @@ class KernelComponentFactory: def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! 
the later in this list will also be checked later
-    # TODO: currently for qr pipeline, let 't' padding to appear later!!
+    # TODO: currently for qr pipeline, let "t" padding appear later!!
     # TODO: how to design this more generic?
     pipelines = []
     if dtype in ["fp32"]:
@@ -719,18 +730,8 @@ class KernelComponentFactory:
             else:
                 pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
                 pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
-            if (
-                (hdim, hdim_v) in [(64, 64), (128, 128)]
-                and logits == "f"
-                and bias == "no"
-                and dropout == "f"
-                and skip == "f"
-            ):
-                pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
-                pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
-            if receipt == 1 and bias != "bias":
-                pipelines.append(FmhaFwdPipeline( "qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim
+            pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitrary hdim
     elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
         # no need lse/dropout kernels
         for logits, squant, mask, bias in itertools.product(
@@ -746,29 +747,128 @@ class KernelComponentFactory:
         return pipelines
 
-class CustomFactory(KernelComponentFactory):
+class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
+    arch = ArchTrait("gfx950")
+
+    @staticmethod
+    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
+        pipelines = KernelComponentFactoryGfx9.get_pipelines(
+            dtype, hdim, hdim_v, receipt, mask_impl
+        )
+        if dtype in ["fp16", "bf16"]:
+            squant = "f"
+            for logits, mask, bias, lse, dropout, skip in itertools.product(
+                ["t", "f"],
+                get_mask_map(mask_impl).keys(),
+                BIAS_MAP.keys(),
+                ["t", "f"],
+                ["t", "f"],
+                ["t", "f"],
+            ):
+                if (
+                    (hdim, hdim_v) in [(64, 64), (128, 128)]
+                    and logits == "f"
+                    and bias == "no"
+                    and dropout == "f"
+                    and skip == "f"
+                ):
+                    pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
+                    pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
+        return pipelines
+
+
+class KernelComponentFactoryGfx12:
+    arch = ArchTrait("gfx12")
+
+    @staticmethod
+    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
+        if dtype in ["fp16", "bf16"]:
+            return {
+                # bm0, bn0, bk0, bn1, bk1,
+                ( 32,  32) : [FmhaFwdTileSize( 64,  64, 16,  32, 32,  32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                ( 64,  64) : [FmhaFwdTileSize( 64,  64, 32,  64, 32,  64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                (128, 128) : [FmhaFwdTileSize( 64,  64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                (192, 128) : [FmhaFwdTileSize( 64,  64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                (256, 256) : [FmhaFwdTileSize( 64,  64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+            } # fmt: skip
+        elif dtype in ["fp8", "fp8bf16"]:
+            return {
+                # bm0, bn0, bk0, bn1, bk1,
+                ( 64,  64) : 
[FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (256, 256) : [FmhaFwdTileSize( 64, 32, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + } # fmt: skip + elif dtype in ["fp8fp32"]: + return { + # bm0, bn0, bk0, bn1, bk1, + (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + } # fmt: skip + else: + return None + + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + pipelines = [] + if dtype in ["fp16", "bf16"]: + squant = "f" + for logits, mask, bias, lse, dropout, skip in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t", "f"], + ["t", "f"], + ["t", "f"], + ): + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: + # no need lse/dropout kernels + for logits, squant, mask, bias in itertools.product( + ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() + ): + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip + else: + assert False + return pipelines + + +class CustomFactory(KernelComponentFactoryGfx9): + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) if dtype == "fp16" or dtype == "bf16": if (128, 128) in result.keys(): result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip return result +def get_factory(target: str): + if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1": + return CustomFactory + + # Place more specific architectures first + + if target.startswith("gfx950"): + return KernelComponentFactoryGfx950 + if target.startswith("gfx9"): + return KernelComponentFactoryGfx9 + + if target.startswith("gfx12"): + return KernelComponentFactoryGfx12 + + raise Exception(f"Unsupported device target {target}") + + def get_fwd_blobs( - kernel_filter: Optional[str], receipt, optdim_list, mask_impl + targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() api_pool = FmhaFwdApiPool(mask_impl) - factory = ( - CustomFactory - if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" - else KernelComponentFactory - ) + factories = get_factories_for_targets(targets, get_factory) - for dtype in FWD_DTYPE_MAP.keys(): + for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()): d = factory.get_hdim_tile_size_dict(dtype) if d is None: continue @@ -791,7 +891,8 @@ def get_fwd_blobs( # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != "no" or pipeline.F_dropout == "t": continue - if dtype != "fp32": + if factory.arch.name.startswith("gfx9") and dtype != "fp32": + # TODO: update if >=gfx11 archs get qr_async and 
qr_async_trload support if pipeline.tag != "qr_async_trload" and ( ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128) @@ -811,6 +912,7 @@ def get_fwd_blobs( ): continue k = FmhaFwdKernel( + F_arch=factory.arch, F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -918,19 +1020,33 @@ def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: def write_blobs( - output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl + targets: List[str], + output_dir: Path, + kernel_filter: str, + receipt, + optdim_list, + mask_impl, ) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + api_pool, kernels = get_fwd_blobs( + targets, kernel_filter, receipt, optdim_list, mask_impl + ) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) def list_blobs( - file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl + targets: List[str], + file_path: Path, + kernel_filter: str, + receipt, + optdim_list, + mask_impl, ) -> None: with file_path.open("a") as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + _, kernels = get_fwd_blobs( + targets, kernel_filter, receipt, optdim_list, mask_impl + ) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index fcbf22fb18..ee24949cca 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
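# ---------------------------------------------------------------------------
# Editorial sketch (not part of the diff). The generators in this PR import
# ArchTrait / get_factories_for_targets from codegen.arch and
# check_duplicates_and_paddings / if_ / indent / update_file from
# codegen.utils, but neither module appears in the diff. The definitions
# below are a hedged reconstruction inferred purely from the call sites;
# the real code may differ in detail.
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen => hashable, so traits can key the pools
class ArchTrait:
    name: str
    # C++ condition used by the generated per-kernel guard:
    #   #if !defined(__HIP_DEVICE_COMPILE__) || (<preprocessor_check>)
    preprocessor_check: str = ""

    def __post_init__(self):
        if not self.preprocessor_check:
            # assumed default, e.g. "gfx950" -> "defined(__gfx950__)";
            # families like gfx12 may need a broader rule in the real module
            object.__setattr__(
                self, "preprocessor_check", f"defined(__{self.name}__)"
            )

    @property
    def filename_suffix(self) -> str:
        return f"_{self.name}"  # keeps per-arch .cpp files from colliding

    def __format__(self, spec: str) -> str:
        return self.name  # "{F_arch}" in the C++ templates -> arch name


def get_factories_for_targets(targets, get_factory):
    # Map each compile target (e.g. "gfx90a", "gfx1201") to its factory,
    # deduplicating while preserving order so generation is reproducible.
    factories = []
    for target in targets:
        factory = get_factory(target)
        if factory not in factories:
            factories.append(factory)
    return factories


def if_(i: int) -> str:
    # First branch of a generated dispatch chain is "if", the rest "else if".
    return "if" if i == 0 else "else if"


def indent(text: str, width: int = 4) -> str:
    # Indent every non-empty line of a generated C++ fragment by one level.
    return "\n".join(
        (" " * width + line) if line.strip() else line
        for line in text.split("\n")
    )


def update_file(path, content: str) -> None:
    # Rewrite an autogenerated file only when its content actually changed,
    # so incremental builds do not recompile untouched kernels.
    if not path.exists() or path.read_text() != content:
        path.write_text(content)


def check_duplicates_and_paddings(traits: list, new_trait) -> None:
    # Refuse to register two traits that would generate identical dispatch
    # conditions (the earlier branch would shadow the later one); the real
    # helper presumably also validates the padding ordering.
    assert all(t != new_trait for t in traits), f"duplicated trait: {new_trait}"
# ---------------------------------------------------------------------------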
# generate kernel instances to speed up compilation import copy -from dataclasses import dataclass import fnmatch +import itertools +from collections import OrderedDict +from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Tuple +from codegen.arch import ArchTrait, get_factories_for_targets from codegen.cmake_config import GEN_DIR from codegen.cpp_symbol_map import ( FWD_DTYPE_MAP, @@ -16,16 +19,21 @@ from codegen.cpp_symbol_map import ( LAYOUT_MAP, ROPE_CHECK_MAP, ) +from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file from codegen.ops.fmha_fwd import ( - FmhaFwdApiTrait, FMHA_FWD_KERNEL_HEADER, + FMHA_FWD_API_PER_ARCH, FMHA_FWD_API_PER_DTYPE, FMHA_FWD_API_PER_HDIM_CASE, ) FMHA_FWD_APPENDKV_KERNEL_BODY = """ +#include + +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) + using fmha_dtype_{F_idx} = {F_dtype}; using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad}, @@ -55,10 +63,8 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel; -#include - template<> -float fmha_fwd_appendkv_(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a) +float fmha_fwd_appendkv_(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a) {{ using k_ = fmha_kernel_{F_idx}; if(s.log_level_ > 0) @@ -66,31 +72,37 @@ float fmha_fwd_appendkv_(const ck_tile::stream_config& s, fmha_fw auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) """ FMHA_FWD_APPENDKV_API_FILENAME = "fmha_fwd_appendkv_api.cpp" FMHA_FWD_APPENDKV_API = """ -float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{ +float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s) {{ float r = -1; + + [[maybe_unused]] const std::string device_name = ck_tile::get_device_name(); + {F_dispatch} return r; }} """ -FMHA_FWD_APPENDKV_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{ - using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; - return fmha_fwd_appendkv_(s, a); - }} +FMHA_FWD_APPENDKV_API_INNER_DISPATCH = """{F_if}((t.is_v_rowmajor == {F_vlayout}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) && + ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{ + using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; + return fmha_fwd_appendkv_(s, a); +}} """ @dataclass class FmhaFwdAppendKVApiTrait: - # sync with fmha_fwd_traits<>, to generate fallback calls + arch: ArchTrait + # sync with fmha_fwd_appendkv_traits, to generate fallback calls hdim: str dtype: str # data type bs: int # tile size along q seqlen @@ -178,62 +190,70 @@ class FmhaFwdAppendKVPipeline: class FmhaFwdAppendKVApiPool: def __init__(self, 
mask_impl): - self.pool = dict() + self.pool = OrderedDict() self.mask_impl = mask_impl - def register_traits(self, trait: FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() - - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + def register_traits(self, trait: FmhaFwdAppendKVApiTrait) -> None: + hdim = trait.hdim + ts = ( + self.pool.setdefault(trait.arch, OrderedDict()) + .setdefault(trait.dtype, OrderedDict()) + .setdefault(hdim, []) + ) + check_duplicates_and_paddings(ts, trait) + ts.append(copy.copy(trait)) @property def api(self) -> str: - per_dtypes = str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case = str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits = self.pool[dtype][hdim] - inners = str() - for k, trait in enumerate(traits): - if_k = "if" if k == 0 else "else if" - inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format( - F_if=if_k, - F_vlayout=LAYOUT_MAP[trait.vlayout], - F_scheck=trait.scheck, - F_skcheck=trait.skcheck, - F_dcheck=trait.dcheck, - F_dvcheck=trait.dvcheck, - F_rope_check=ROPE_CHECK_MAP[trait.rope], - F_pagedkv=BOOL_MAP[trait.pagedkv], - F_spad=BOOL_MAP[trait.spad], - F_skpad=BOOL_MAP[trait.skpad], - F_dpad=BOOL_MAP[trait.dpad], - F_dvpad=BOOL_MAP[trait.dvpad], - F_rope=ROPE_MAP[trait.rope], - F_bs=trait.bs, - F_bsk=trait.bsk, - F_bd=trait.bd, - F_bdv=trait.bdv, + per_arch = str() + for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): + per_dtypes = str() + for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): + per_hdim_case = str() + for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()): + inners = str() + for i_trait, trait in enumerate(pool_by_hdim): + inners += FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format( + F_if=if_(i_trait), + F_arch=arch, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_scheck=trait.scheck, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_rope_check=ROPE_CHECK_MAP[trait.rope], + F_pagedkv=BOOL_MAP[trait.pagedkv], + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_rope=ROPE_MAP[trait.rope], + F_bs=trait.bs, + F_bsk=trait.bsk, + F_bd=trait.bd, + F_bdv=trait.bdv, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_(i_hdim), F_hdim=hdim, - F_dtype=FWD_DTYPE_MAP[dtype], + F_hdim_v=hdim, + F_inner_dispatch=indent(inners), ) - if_j = "if" if j == 0 else "else if" - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( - F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners + per_dtypes += FMHA_FWD_API_PER_DTYPE.format( + F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case) ) - if_i = "if" if i == 0 else "else if" - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( - F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + per_arch += FMHA_FWD_API_PER_ARCH.format( + F_if=if_(i_arch), + F_arch=arch, + F_dtype_case=indent(per_dtypes), ) - if not per_dtypes: + if not per_arch: # empty string we add some ignore to suppress warning in api - per_dtypes += " (void)t ; (void)s ; (void)a;" + per_arch = "(void)t; (void)s; (void)a;" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format( - F_dispatch=per_dtypes + F_dispatch=indent(per_arch) ) @@ -254,6 
+274,7 @@ class FmhaFwdAppendKVTileSize: @dataclass class FmhaFwdAppendKVKernel: + F_arch: ArchTrait F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type @@ -265,6 +286,7 @@ class FmhaFwdAppendKVKernel: def template(self) -> str: return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_KERNEL_BODY.format( F_idx=self.F_idx, + F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=FWD_DTYPE_MAP[self.F_dtype], F_bs=self.F_tile.F_bs, @@ -293,10 +315,11 @@ class FmhaFwdAppendKVKernel: @property def filename(self) -> str: - return self.name + ".cpp" + return f"{self.name}{self.F_arch.filename_suffix}.cpp" def api_trait(self) -> FmhaFwdAppendKVApiTrait: return FmhaFwdAppendKVApiTrait( + arch=self.F_arch, hdim=str(self.F_hdim), dtype=self.F_dtype, bs=self.F_tile.F_bs, @@ -313,31 +336,26 @@ class FmhaFwdAppendKVKernel: ) -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype: str) -> Optional[dict]: - if dtype == "fp16" or dtype == "bf16": - return { - "32": FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1), - "64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), - "128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), - "256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1), - } - elif dtype == "fp8" or dtype == "bf8": - return { - "64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), - "128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), - "256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1), - } - else: - return None +class KernelComponentFactoryBase: + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return { + "32": FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1), + "64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), + "128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), + "256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1), + } + elif dtype in ["fp8", "bf8"]: + return { + "64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), + "128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), + "256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1), + } + else: + return None - -def get_fwd_appendkv_blobs( - kernel_filter: Optional[str], receipt, mask_impl, optdim_list -) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future + @staticmethod def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! 
the later in this list will be also be checked later @@ -347,19 +365,18 @@ def get_fwd_appendkv_blobs( if dtype in ["fp16", "bf16"]: # NOTICE: it will be very complicated if we consider all the hdim_q padding cases while # applying rotary embedding, so I just use 't' in inter/half pipelines - for vlayout in ["row", "col"]: - for pagedkv in ["t", "f"]: - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "f", "f", "no", pagedkv)) # fmt: skip - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "no", pagedkv)) # fmt: skip + for vlayout, pagedkv in itertools.product(["row"], ["t", "f"]): + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "f", "f", "no", pagedkv)) # fmt: skip + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "no", pagedkv)) # fmt: skip - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "inter", pagedkv)) # fmt: skip - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "inter", pagedkv)) # fmt: skip + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "inter", pagedkv)) # fmt: skip + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "inter", pagedkv)) # fmt: skip - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "half", pagedkv)) # fmt: skip - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "half", pagedkv)) # fmt: skip + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "half", pagedkv)) # fmt: skip + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "half", pagedkv)) # fmt: skip elif dtype in ["fp8", "bf8"]: # rope/paged-kv is not supported - pipelines.append(FmhaFwdAppendKVPipeline("col", "t", "t", "t", "t", "no", "f")) # fmt: skip + pipelines.append(FmhaFwdAppendKVPipeline("row", "t", "t", "t", "t", "no", "f")) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None @@ -367,18 +384,45 @@ def get_fwd_appendkv_blobs( assert False return pipelines + +class KernelComponentFactoryGfx9(KernelComponentFactoryBase): + arch = ArchTrait("gfx9") + + +class KernelComponentFactoryGfx12(KernelComponentFactoryBase): + arch = ArchTrait("gfx12") + + +def get_factory(target: str): + # Place more specific architectures first + + if target.startswith("gfx9"): + return KernelComponentFactoryGfx9 + + if target.startswith("gfx12"): + return KernelComponentFactoryGfx12 + + raise Exception(f"Unsupported device target {target}") + + +def get_fwd_appendkv_blobs( + targets: List[str], kernel_filter: Optional[str], receipt, mask_impl, optdim_list +) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: gen = list() api_pool = FmhaFwdAppendKVApiPool(mask_impl) - for dtype in FWD_DTYPE_MAP.keys(): - d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) + factories = get_factories_for_targets(targets, get_factory) + + for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()): + d = factory.get_hdim_tile_size_dict(dtype) if d is None: continue for hdim_str in d.keys(): tile = d[hdim_str] hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): + for pipeline in factory.get_pipelines(dtype, hdim): k = FmhaFwdAppendKVKernel( + F_arch=factory.arch, F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -418,18 +462,23 @@ def get_fwd_appendkv_blobs( def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) + update_file(autogen_dir / kernel.filename, kernel.template) def 
write_fwd_appendkv_api(api_pool: FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) + update_file(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME, api_pool.api) def write_blobs( - output_dir: Path, kernel_filter: Optional[str], receipt, optdim_list, mask_impl + targets: List[str], + output_dir: Path, + kernel_filter: Optional[str], + receipt, + optdim_list, + mask_impl, ) -> None: api_pool, kernels = get_fwd_appendkv_blobs( - kernel_filter, receipt, mask_impl, optdim_list + targets, kernel_filter, receipt, mask_impl, optdim_list ) for kernel in kernels: write_single_kernel(kernel, output_dir) @@ -437,11 +486,16 @@ def write_blobs( def list_blobs( - file_path: Path, kernel_filter: Optional[str], receipt, optdim_list, mask_impl + targets: List[str], + file_path: Path, + kernel_filter: Optional[str], + receipt, + optdim_list, + mask_impl, ) -> None: with file_path.open("a") as f: _, kernels = get_fwd_appendkv_blobs( - kernel_filter, receipt, mask_impl, optdim_list + targets, kernel_filter, receipt, mask_impl, optdim_list ) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 31a35ecb97..85c25561ea 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -3,12 +3,14 @@ # generate kernel instances to speed up compilation import copy -from dataclasses import dataclass import fnmatch import itertools +from collections import OrderedDict +from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union +from codegen.arch import ArchTrait, get_factories_for_targets from codegen.cmake_config import GEN_DIR from codegen.cpp_symbol_map import ( PIPELINE_ENUM_MAP, @@ -21,32 +23,29 @@ from codegen.cpp_symbol_map import ( get_mask_map, BOOL_MAP, ) +from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file from codegen.ops.fmha_fwd import ( FmhaFwdTileSize, + DTYPE_BITS, + K0_MAX_SUBMAX_MAP, FMHA_FWD_KERNEL_HEADER, + FMHA_FWD_API_PER_ARCH, FMHA_FWD_API_PER_DTYPE, FMHA_FWD_API_PER_HDIM_CASE, ) -DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} - -K0_MAX_SUBMAX_MAP = { - 32: 32, - 64: 64, - 96: 128, - 128: 128, - # 160: 160, - 256: 256, -} - FMHA_FWD_SPLITKV_PIPELINE_MAP = { "qr": "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", "qr_nwarp_sshuffle": "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", } FMHA_FWD_SPLITKV_KERNEL_BODY = """ +#include + +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) + using fmha_dtype_{F_idx} = {F_dtype}; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; using fmha_mask_{F_idx} = {F_mask}; @@ -113,17 +112,15 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); -}} -}}; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} +}}; // struct instance +}} // anonymous namespace using trait_{F_idx} = 
fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; -#include - #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wtautological-compare" @@ -147,7 +144,7 @@ void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ #pragma clang diagnostic pop template<> -void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if constexpr({F_mode} == false) {{ // batch mode // we don't check every seqlen_k values for kvcache @@ -165,14 +162,20 @@ void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, f }} template<> -std::string fmha_fwd_splitkv_get_name_() +std::string fmha_fwd_splitkv_get_name_() {{ using k_ = instance::fmha_kernel; /// FIXME: choose real kernel type return k_::GetName(); }} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) """ FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY = """ +#include + +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) + using fmha_dtype_{F_idx} = {F_dtype}; namespace {{ @@ -213,18 +216,16 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); -}} -}}; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} +}}; // struct instance +}} // anonymous namespace using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; -#include - template<> -void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if (a.num_splits <= 8) {{ instance<3>::run(s, a); @@ -240,73 +241,79 @@ void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_conf }} template<> -std::string fmha_fwd_splitkv_combine_get_name_() +std::string fmha_fwd_splitkv_combine_get_name_() {{ using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type return k_::GetName(); }} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) """ FMHA_FWD_SPLITKV_API_FILENAME = "fmha_fwd_splitkv_api.cpp" FMHA_FWD_SPLITKV_API = """ #include -template +template float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if(s.log_level_ > 0) - std::cout - << ", " << fmha_fwd_splitkv_get_name_() - << ", " << fmha_fwd_splitkv_combine_get_name_() - << std::flush; + std::cout + << ", " << fmha_fwd_splitkv_get_name_() + << ", " << fmha_fwd_splitkv_combine_get_name_() + << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} ); }} -float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, 
fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{ +float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s) {{ float r = -1; + + [[maybe_unused]] const std::string device_name = ck_tile::get_device_name(); + {F_dispatch} return r; }} """ -FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; +FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - // get combine kernel tile sizes - using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; - constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes::kM0; + // get combine kernel tile sizes + using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; + constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes::kM0; - // make sure we can reuse the padding flags in combine kernels - static_assert({F_bm0} % kM0 == 0); - static_assert({F_bn1} % 32 == 0); + // make sure we can reuse the padding flags in combine kernels + static_assert({F_bm0} % kM0 == 0); + static_assert({F_bn1} % {F_bn1comb} == 0); - if (t.has_lse) {{ - if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ - return -1; - }} else {{ - using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>; + if (t.has_lse) {{ + if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ + return -1; + }} else {{ + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1comb}, true, {F_squant}, {F_spad}, {F_dvpad}>; - return fmha_fwd_splitkv_(s, a); - }} - }} else {{ - using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>; + return fmha_fwd_splitkv_(s, a); + }} + }} else {{ + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1comb}, false, {F_squant}, {F_spad}, {F_dvpad}>; - return fmha_fwd_splitkv_(s, a); - }} - }} + return fmha_fwd_splitkv_(s, a); + }} +}} """ @dataclass class FmhaFwdSplitKVApiTrait: + arch: ArchTrait pipeline_tag: str # sync with fmha_fwd_traits<>, to generate fallback calls - hdim: str + hdim: int dtype: str # data type mode: str # value from MODE_MAP bm0: int # tile size along q seqlen (block size) @@ -326,6 +333,7 @@ class FmhaFwdSplitKVApiTrait: dpad: str 
dvpad: str pagedkv: str + bn1comb: int # tile size along v head_dim of combine kernel @property def name(self) -> str: @@ -523,71 +531,80 @@ class FmhaFwdSplitKVCombinePipeline: class FmhaFwdSplitKVApiPool: def __init__(self, mask_impl): - self.pool = dict() + self.pool = OrderedDict() self.mask_impl = mask_impl def register_traits(self, trait: FmhaFwdSplitKVApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() - - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + hdim = trait.hdim + ts = ( + self.pool.setdefault(trait.arch, OrderedDict()) + .setdefault(trait.dtype, OrderedDict()) + .setdefault(hdim, []) + ) + check_duplicates_and_paddings(ts, trait) + ts.append(copy.copy(trait)) @property def api(self) -> str: - per_dtypes = str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case = str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits = self.pool[dtype][hdim] - inners = str() - for k, trait in enumerate(traits): - if_k = "if" if k == 0 else "else if" - inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format( - F_if=if_k, - F_mode=MODE_MAP[trait.mode], - F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], - F_logits=BOOL_MAP[trait.logits], - F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], - F_bias_check=BIAS_CHECK_MAP[trait.bias], - F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], - F_squant=BOOL_MAP[trait.squant], - F_pagedkv=BOOL_MAP[trait.pagedkv], - F_scheck=trait.scheck, - F_skcheck=trait.skcheck, - F_dcheck=trait.dcheck, - F_dvcheck=trait.dvcheck, - F_spad=BOOL_MAP[trait.spad], - F_skpad=BOOL_MAP[trait.skpad], - F_dpad=BOOL_MAP[trait.dpad], - F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, - F_bn0=trait.bn0, - F_bk0=trait.bk0, - F_bn1=trait.bn1, - F_bk1=trait.bk1, - F_bk0max=trait.bk0max, + per_arch = str() + for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): + per_dtypes = str() + for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): + per_hdim_case = str() + for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()): + inners = str() + for i_trait, trait in enumerate(pool_by_hdim): + inners += FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format( + F_if=if_(i_trait), + F_arch=arch, + F_mode=MODE_MAP[trait.mode], + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_logits=BOOL_MAP[trait.logits], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], + F_squant=BOOL_MAP[trait.squant], + F_pagedkv=BOOL_MAP[trait.pagedkv], + F_scheck=trait.scheck, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + F_bn1comb=trait.bn1comb, + ) + per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_(i_hdim), F_hdim=hdim, - F_dtype=FWD_DTYPE_MAP[dtype], + F_hdim_v=hdim, + 
F_inner_dispatch=indent(inners), ) - if_j = "if" if j == 0 else "else if" - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( - F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners + per_dtypes += FMHA_FWD_API_PER_DTYPE.format( + F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case) ) - if_i = "if" if i == 0 else "else if" - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( - F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + per_arch += FMHA_FWD_API_PER_ARCH.format( + F_if=if_(i_arch), + F_arch=arch, + F_dtype_case=indent(per_dtypes), ) - if not per_dtypes: + if not per_arch: # empty string we add some ignore to suppress warning in api - per_dtypes += " (void)t ; (void)s ; (void)a;" + per_arch = "(void)t; (void)s; (void)a;" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format( - F_dispatch=per_dtypes + F_dispatch=indent(per_arch) ) @@ -605,6 +622,7 @@ class FmhaFwdSplitKVCombineTileSize: @dataclass class FmhaFwdSplitKVKernel: + F_arch: ArchTrait F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type @@ -615,8 +633,10 @@ class FmhaFwdSplitKVKernel: @property def template(self) -> str: + assert self.F_pipeline.F_lse == "t" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_KERNEL_BODY.format( F_idx=self.F_idx, + F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=FWD_DTYPE_MAP[self.F_dtype], F_bm0=self.F_tile.F_bm0, @@ -666,36 +686,12 @@ class FmhaFwdSplitKVKernel: @property def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdSplitKVApiTrait: - return FmhaFwdSplitKVApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - logits=self.F_pipeline.F_logits, - mask=self.F_pipeline.F_mask, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - squant=self.F_pipeline.F_squant, - pagedkv=self.F_pipeline.F_pagedkv, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - ) + return f"{self.name}{self.F_arch.filename_suffix}.cpp" @dataclass class FmhaFwdSplitKVCombineKernel: + F_arch: ArchTrait F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type @@ -707,6 +703,7 @@ class FmhaFwdSplitKVCombineKernel: def template(self) -> str: return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( F_idx=self.F_idx, + F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=FWD_DTYPE_MAP[self.F_dtype], F_bn1=self.F_tile.F_bn1, @@ -730,85 +727,33 @@ class FmhaFwdSplitKVCombineKernel: @property def filename(self) -> str: - return self.name + ".cpp" + return f"{self.name}{self.F_arch.filename_suffix}.cpp" -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]: - if dtype == "fp16" or dtype == "bf16": - return { - "32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 
1, 16, 16, 16, 16, 16, 16, -1), - # "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - } # fmt: skip - elif dtype == "fp8" or dtype == "bf8": - return { - "64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - } # fmt: skip - else: - return None - - -def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype: str) -> Optional[dict]: - if dtype == "fp16" or dtype == "bf16": - return { - "32": FmhaFwdSplitKVCombineTileSize(32, -1), - "64": FmhaFwdSplitKVCombineTileSize(32, -1), - "96": FmhaFwdSplitKVCombineTileSize(32, -1), - "128": FmhaFwdSplitKVCombineTileSize(32, -1), - # '160' : FmhaFwdSplitKVCombineTileSize(32, -1), - "256": FmhaFwdSplitKVCombineTileSize(32, -1), - } - elif dtype == "fp8" or dtype == "bf8": - return { - "64": FmhaFwdSplitKVCombineTileSize(32, -1), - "128": FmhaFwdSplitKVCombineTileSize(32, -1), - "256": FmhaFwdSplitKVCombineTileSize(32, -1), - } - else: - return None - - -def get_fwd_splitkv_blobs( - kernel_filter: Optional[str], receipt, mask_impl, optdim_list -) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: - Pipeline = FmhaFwdSplitKVPipeline - Kernel = FmhaFwdSplitKVKernel - - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]: +class KernelComponentFactoryBase: + @staticmethod + def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]: # this function will populate a list of possible pipelines # TODO: the order of List matters! the later in this list will also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: currently for qr pipeline, let "t" padding appear later!! # TODO: how to design this more generic? 
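# ---------------------------------------------------------------------------
# Editorial note (not part of the diff): "checked later" in the TODO above
# refers to the generated dispatch order. register_traits() keeps traits in
# insertion order (OrderedDict plus list.append), and the api property emits
# one "if"/"else if" branch per trait in that same order, so a pipeline
# appended later is only reached once every earlier branch's feature and
# padding checks have failed. That is why the fully padded variants are
# appended last: they are the least efficient but always-applicable fallback,
# and check_duplicates_and_paddings() presumably enforces that no earlier
# trait shadows a later one.
# ---------------------------------------------------------------------------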
+ Pipeline = FmhaFwdSplitKVPipeline squant = "t" if dtype == "fp8" else "f" pipelines = [] if dtype in ["fp16", "bf16"]: for logits, mask, bias, pagedkv in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"] ): - pipelines.append(Pipeline( "qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline( "qr", "col", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - - pipelines.append(Pipeline( "qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline( "qr", "col", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - - pipelines.append(Pipeline( "qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline( "qr", "col", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - - pipelines.append(Pipeline( "qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline( "qr", "col", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip elif dtype in ["fp8", "bf8"]: for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(Pipeline( "qr", "col", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None @@ -816,18 +761,122 @@ def get_fwd_splitkv_blobs( assert False return pipelines - gen = list() - api_pool = FmhaFwdSplitKVApiPool(mask_impl) + @staticmethod + def get_combine_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]: + Pipeline = FmhaFwdSplitKVCombinePipeline + squant = "t" if dtype == "fp8" else "f" + pipelines = [] + if dtype in ["fp16", "bf16"]: + for spad, dvpad, lse in itertools.product( + ["t", "f"], ["t", "f"], ["t", "f"] + ): + pipelines.append(Pipeline("unused", spad, dvpad, lse, squant)) + elif dtype in ["fp8", "bf8"]: + # no need lse kernels + for spad, dvpad in itertools.product(["t", "f"], ["t", "f"]): + pipelines.append(Pipeline("unused", spad, dvpad, "f", squant)) + else: + assert False + return pipelines - for dtype in FWD_DTYPE_MAP.keys(): - d = get_fmha_fwd_tile_dict_from_dtype(dtype) + @staticmethod + def get_combine_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + # Possible values of F_bn1: 8, 16, 32 + if dtype in ["fp16", "bf16"]: + return { + "32": FmhaFwdSplitKVCombineTileSize(32, -1), + "64": FmhaFwdSplitKVCombineTileSize(32, -1), + "96": FmhaFwdSplitKVCombineTileSize(32, -1), + "128": FmhaFwdSplitKVCombineTileSize(32, -1), + # "160" : FmhaFwdSplitKVCombineTileSize(32, -1), + "256": FmhaFwdSplitKVCombineTileSize(32, -1), + } + elif dtype in ["fp8", "bf8"]: + return { + "64": 
FmhaFwdSplitKVCombineTileSize(32, -1), + "128": FmhaFwdSplitKVCombineTileSize(32, -1), + "256": FmhaFwdSplitKVCombineTileSize(32, -1), + } + else: + return None + + +class KernelComponentFactoryGfx9(KernelComponentFactoryBase): + arch = ArchTrait("gfx9") + + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return { + "32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + } # fmt: skip + elif dtype in ["fp8", "bf8"]: + return { + "64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } # fmt: skip + else: + return None + + +class KernelComponentFactoryGfx12(KernelComponentFactoryBase): + arch = ArchTrait("gfx12") + + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return { + # bm0, bn0, bk0, bn1, bk1, + "32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + } # fmt: skip + elif dtype in ["fp8", "bf8"]: + return { + # bm0, bn0, bk0, bn1, bk1, + "64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + } # fmt: skip + else: + return None + + +def get_factory(target: str): + # Place more specific architectures first + + if target.startswith("gfx9"): + return KernelComponentFactoryGfx9 + + if target.startswith("gfx12"): + return KernelComponentFactoryGfx12 + + raise Exception(f"Unsupported device target {target}") + + +def get_fwd_splitkv_blobs( + targets: List[str], kernel_filter: Optional[str], receipt, mask_impl, optdim_list +) -> List[FmhaFwdSplitKVKernel]: + Kernel = FmhaFwdSplitKVKernel + + gen = list() + + factories = get_factories_for_targets(targets, get_factory) + + for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()): + d = factory.get_hdim_tile_size_dict(dtype) if d is None: continue # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): tile = d[hdim_str] hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): + for pipeline in factory.get_pipelines(dtype, hdim, mask_impl): if mode == "group": if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not @@ -839,6 +888,7 @@ def get_fwd_splitkv_blobs( ): continue k = Kernel( + F_arch=factory.arch, F_idx=0, 
F_hdim=hdim, F_dtype=dtype, @@ -892,55 +942,34 @@ def get_fwd_splitkv_blobs( if not cond: continue - api_pool.register_traits(k.api_trait()) gen.append(k) - return (api_pool, gen) + return gen def get_fwd_splitkv_combine_blobs( - kernel_filter: Optional[str], receipt, optdim_list + targets: List[str], kernel_filter: Optional[str], receipt, optdim_list ) -> List[FmhaFwdSplitKVCombineKernel]: - Pipeline = FmhaFwdSplitKVCombinePipeline Kernel = FmhaFwdSplitKVCombineKernel - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = "t" if dtype == "fp8" else "f" - pipelines = [] - if dtype in ["fp16", "bf16"]: - for spad, dvpad, lse in itertools.product( - ["t", "f"], ["t", "f"], ["t", "f"] - ): - pipelines.append(Pipeline("unused", spad, dvpad, lse, squant)) - elif dtype in ["fp8", "bf8"]: - # no need lse kernels - pipelines.append(Pipeline("unused", "f", "f", "f", squant)) - else: - assert False - return pipelines - gen = list() - for dtype in FWD_DTYPE_MAP.keys(): - d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) + factories = get_factories_for_targets(targets, get_factory) + + for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()): + d = factory.get_combine_hdim_tile_size_dict(dtype) if d is None: continue - # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): tile = d[hdim_str] hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): + for pipeline in factory.get_combine_pipelines(dtype, hdim): if mode == "group": if pipeline.F_spad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue k = Kernel( + F_arch=factory.arch, F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -980,43 +1009,102 @@ def get_fwd_splitkv_combine_blobs( def write_single_kernel( kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path ) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) + update_file(autogen_dir / kernel.filename, kernel.template) def write_fwd_splitkv_api(api_pool: FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None: - file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME - file_path.write_text(api_pool.api) + update_file(autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME, api_pool.api) def write_blobs( - output_dir: Path, filter_list: str, receipt, optdim_list, mask_impl + targets: List[str], + output_dir: Path, + filter_list: str, + receipt, + optdim_list, + mask_impl, ) -> None: filter_list = filter_list.split("@") filter_list.extend([""] * (2 - len(filter_list))) - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) - for kernel in kernels: + combine_kernels = get_fwd_splitkv_combine_blobs( + targets, filter_list[0], receipt, optdim_list + ) + for kernel in combine_kernels: write_single_kernel(kernel, output_dir) - api_pool, kernels = get_fwd_splitkv_blobs( - filter_list[1], receipt, mask_impl, optdim_list + kernels = get_fwd_splitkv_blobs( + targets, filter_list[1], receipt, mask_impl, optdim_list ) 
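# ---------------------------------------------------------------------------
# Editorial note (not part of the diff): trait registration for splitkv moved
# out of get_fwd_splitkv_blobs() into the pairing loop below, because each
# dispatch entry now needs the tile size of its combine kernel (bn1comb) for
# the generated static_assert({F_bn1} % {F_bn1comb} == 0). The loop demands
# exactly one lse="f" combine kernel whose arch, hdim, dtype, mode,
# spad/dvpad and squant settings match the splitkv kernel, so the assert
# fires at generation time if the two tile tables drift out of sync.
# ---------------------------------------------------------------------------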
for kernel in kernels: write_single_kernel(kernel, output_dir) + + api_pool = FmhaFwdSplitKVApiPool(mask_impl) + for kernel in kernels: + combine_ks = [ + k + for k in combine_kernels + if k.F_arch == kernel.F_arch + and k.F_hdim == kernel.F_hdim + and k.F_dtype == kernel.F_dtype + and k.F_mode == kernel.F_mode + and k.F_pipeline.F_spad == kernel.F_pipeline.F_spad + and k.F_pipeline.F_dvpad == kernel.F_pipeline.F_dvpad + and k.F_pipeline.F_lse == "f" + and k.F_pipeline.F_squant == kernel.F_pipeline.F_squant + ] + assert len(combine_ks) == 1, ( + f"{len(combine_ks)} matching FmhaFwdSplitKVCombineKernel for {kernel}" + ) + combine_kernel = combine_ks[0] + api_pool.register_traits( + FmhaFwdSplitKVApiTrait( + arch=kernel.F_arch, + pipeline_tag=kernel.F_pipeline.tag, + hdim=kernel.F_hdim, + dtype=kernel.F_dtype, + mode=kernel.F_mode, + bm0=kernel.F_tile.F_bm0, + bn0=kernel.F_tile.F_bn0, + bk0=kernel.F_tile.F_bk0, + bn1=kernel.F_tile.F_bn1, + bk1=kernel.F_tile.F_bk1, + bk0max=kernel.F_tile.F_bk0max, + vlayout=kernel.F_pipeline.F_vlayout, + logits=kernel.F_pipeline.F_logits, + mask=kernel.F_pipeline.F_mask, + bias=kernel.F_pipeline.F_bias, + lse=kernel.F_pipeline.F_lse, + squant=kernel.F_pipeline.F_squant, + pagedkv=kernel.F_pipeline.F_pagedkv, + spad=kernel.F_pipeline.F_spad, + skpad=kernel.F_pipeline.F_skpad, + dpad=kernel.F_pipeline.F_dpad, + dvpad=kernel.F_pipeline.F_dvpad, + bn1comb=combine_kernel.F_tile.F_bn1, + ) + ) write_fwd_splitkv_api(api_pool, output_dir) def list_blobs( - file_path: Path, filter_list: str, receipt, optdim_list, mask_impl + targets: List[str], + file_path: Path, + filter_list: str, + receipt, + optdim_list, + mask_impl, ) -> None: filter_list = filter_list.split("@") filter_list.extend([""] * (2 - len(filter_list))) with file_path.open("a") as f: - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) + kernels = get_fwd_splitkv_combine_blobs( + targets, filter_list[0], receipt, optdim_list + ) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_fwd_splitkv_blobs( - filter_list[1], receipt, mask_impl, optdim_list + kernels = get_fwd_splitkv_blobs( + targets, filter_list[1], receipt, mask_impl, optdim_list ) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index f22b0fa52f..17ac129e64 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -3,12 +3,14 @@ # generate kernel instances to speed up compilation import copy -from dataclasses import dataclass import fnmatch import itertools +from collections import OrderedDict +from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Tuple +from codegen.arch import ArchTrait, get_factories_for_targets from codegen.cmake_config import GEN_DIR from codegen.cpp_symbol_map import ( LAYOUT_MAP, @@ -21,24 +23,27 @@ from codegen.cpp_symbol_map import ( BOOL_MAP, PIPELINE_ENUM_MAP, ) +from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file +from codegen.ops.fmha_fwd import ( + DTYPE_BITS, + K0_MAX_SUBMAX_MAP, + FMHA_FWD_KERNEL_HEADER, + FMHA_FWD_API_PER_ARCH, + FMHA_FWD_API_PER_DTYPE, + FMHA_FWD_API_PER_HDIM_CASE, +) -DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} - -K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 
128: 128, 256: 256} FMHA_FWD_PAGEDKV_PIPELINE_MAP = { "qr_pagedkv": "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS" } -FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n -// auto generated by generate.py -#include "ck_tile/ops/fmha/block/variants.hpp" -#include "fmha_fwd.hpp" -""" - FMHA_FWD_KERNEL_BODY = """ +#include + +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) + using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; @@ -98,10 +103,8 @@ using fmha_kernel_{F_idx} = using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; -#include - template<> -float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) +float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) {{ using k_ = fmha_kernel_{F_idx}; if(s.log_level_ > 0) @@ -109,38 +112,35 @@ float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) """ FMHA_FWD_API_FILENAME = "fmha_fwd_pagedkv_api.cpp" FMHA_FWD_API = """ -float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{ +float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s) {{ float r = -1; + + [[maybe_unused]] const std::string device_name = ck_tile::get_device_name(); + {F_dispatch} return r; }} """ -FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ -{F_inner_dispatch} - }} -""" - -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; - return fmha_fwd_pagedkv_(s, a); - }} +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = 
fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + return fmha_fwd_pagedkv_(s, a); +}} """ @dataclass class FmhaFwdApiTrait: + arch: ArchTrait pipeline_tag: str # sync with fmha_fwd_traits<>, to generate fallback calls hdim: str @@ -327,71 +327,79 @@ class FmhaFwdPipeline: class FmhaFwdApiPool: def __init__(self, mask_impl): - self.pool = dict() + self.pool = OrderedDict() self.mask_impl = mask_impl def register_traits(self, trait: FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() - - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + hdim = trait.hdim + ts = ( + self.pool.setdefault(trait.arch, OrderedDict()) + .setdefault(trait.dtype, OrderedDict()) + .setdefault(hdim, []) + ) + check_duplicates_and_paddings(ts, trait) + ts.append(copy.copy(trait)) @property def api(self) -> str: - per_dtypes = str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case = str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits = self.pool[dtype][hdim] - inners = str() - for k, trait in enumerate(traits): - if_k = "if" if k == 0 else "else if" - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( - F_if=if_k, - F_mode=MODE_MAP[trait.mode], - F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], - F_logits=BOOL_MAP[trait.logits], - F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], - F_bias_check=BIAS_CHECK_MAP[trait.bias], - F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], - F_pagedkv=BOOL_MAP[trait.pagedkv], - F_skip=BOOL_MAP[trait.skip], - F_squant=BOOL_MAP[trait.squant], - F_scheck=trait.scheck, - F_skcheck=trait.skcheck, - F_dcheck=trait.dcheck, - F_dvcheck=trait.dvcheck, - F_spad=BOOL_MAP[trait.spad], - F_skpad=BOOL_MAP[trait.skpad], - F_dpad=BOOL_MAP[trait.dpad], - F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, - F_bn0=trait.bn0, - F_bk0=trait.bk0, - F_bn1=trait.bn1, - F_bk1=trait.bk1, - F_bk0max=trait.bk0max, + per_arch = str() + for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): + per_dtypes = str() + for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): + per_hdim_case = str() + for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()): + inners = str() + for i_trait, trait in enumerate(pool_by_hdim): + inners += FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_(i_trait), + F_arch=arch, + F_mode=MODE_MAP[trait.mode], + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_logits=BOOL_MAP[trait.logits], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], + F_pagedkv=BOOL_MAP[trait.pagedkv], + F_skip=BOOL_MAP[trait.skip], + F_squant=BOOL_MAP[trait.squant], + F_scheck=trait.scheck, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + 
F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_(i_hdim), F_hdim=hdim, - F_dtype=FWD_DTYPE_MAP[dtype], + F_hdim_v=trait.bn1, + F_inner_dispatch=indent(inners), ) - if_j = "if" if j == 0 else "else if" - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( - F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners + per_dtypes += FMHA_FWD_API_PER_DTYPE.format( + F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case) ) - if_i = "if" if i == 0 else "else if" - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( - F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + per_arch += FMHA_FWD_API_PER_ARCH.format( + F_if=if_(i_arch), + F_arch=arch, + F_dtype_case=indent(per_dtypes), ) - if not per_dtypes: + if not per_arch: # empty string we add some ignore to suppress warning in api - per_dtypes += " (void)t ; (void)s ; (void)a;" - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes) + per_arch = "(void)t; (void)s; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=indent(per_arch)) @dataclass @@ -428,6 +436,7 @@ class FmhaFwdTileSize: @dataclass class FmhaFwdKernel: + F_arch: ArchTrait F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type @@ -440,6 +449,7 @@ class FmhaFwdKernel: def template(self) -> str: return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( F_idx=self.F_idx, + F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=FWD_DTYPE_MAP[self.F_dtype], F_bm0=self.F_tile.F_bm0, @@ -490,10 +500,11 @@ class FmhaFwdKernel: @property def filename(self) -> str: - return self.name + ".cpp" + return f"{self.name}{self.F_arch.filename_suffix}.cpp" def api_trait(self) -> FmhaFwdApiTrait: return FmhaFwdApiTrait( + arch=self.F_arch, pipeline_tag=self.F_pipeline.tag, hdim=str(self.F_hdim), dtype=self.F_dtype, @@ -519,37 +530,12 @@ class FmhaFwdKernel: ) -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]: - if dtype == "fp16" or dtype == "bf16": - return { - # "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - } # fmt: skip - elif dtype == "fp8" or dtype == "bf8": - return { - "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - } # fmt: skip - else: - return None - - -def get_fwd_blobs( - kernel_filter: Optional[str], receipt, optdim_list, mask_impl -) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in 
future - def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: +class KernelComponentFactoryBase: + @staticmethod + def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!! + # TODO: currently for qr_pagedkv pipeline, let "t" padding to appear later!! # TODO: how to design this more generic? squant = "t" if dtype == "fp8" else "f" pipelines = [] @@ -576,19 +562,85 @@ def get_fwd_blobs( assert False return pipelines + +class KernelComponentFactoryGfx9(KernelComponentFactoryBase): + arch = ArchTrait("gfx9") + + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return { + # "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + } # fmt: skip + elif dtype in ["fp8", "bf8"]: + return { + "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } # fmt: skip + else: + return None + + +class KernelComponentFactoryGfx12(KernelComponentFactoryBase): + arch = ArchTrait("gfx12") + + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return { + # bm0, bn0, bk0, bn1, bk1, + # "32": FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # "64": FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # "192": FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + } # fmt: skip + elif dtype in ["fp8", "bf8"]: + return { + # bm0, bn0, bk0, bn1, bk1, + "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "256": FmhaFwdTileSize( 64, 32, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + } # fmt: skip + else: + return None + + +def get_factory(target: str): + # Place more specific architectures first + + if target.startswith("gfx9"): + return KernelComponentFactoryGfx9 + + if target.startswith("gfx12"): + return KernelComponentFactoryGfx12 + + raise Exception(f"Unsupported device target {target}") + + +def get_fwd_blobs( + targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() api_pool = FmhaFwdApiPool(mask_impl) - for dtype in FWD_DTYPE_MAP.keys(): - 
d = get_fmha_fwd_tile_dict_from_dtype(dtype) + factories = get_factories_for_targets(targets, get_factory) + + for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()): + d = factory.get_hdim_tile_size_dict(dtype) if d is None: continue - # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): tile = d[hdim_str] hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): - # if pipeline.F_pagedkv == 'f': + for pipeline in factory.get_pipelines(dtype, hdim, mask_impl): + # if pipeline.F_pagedkv == "f": # continue if mode == "group": if pipeline.F_spad != "t" or pipeline.F_skpad != "t": @@ -605,6 +657,7 @@ def get_fwd_blobs( ): continue k = FmhaFwdKernel( + F_arch=factory.arch, F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -674,27 +727,41 @@ def get_fwd_blobs( def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) + update_file(autogen_dir / kernel.filename, kernel.template) def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) def write_blobs( - output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl + targets: List[str], + output_dir: Path, + kernel_filter: str, + receipt, + optdim_list, + mask_impl, ) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + api_pool, kernels = get_fwd_blobs( + targets, kernel_filter, receipt, optdim_list, mask_impl + ) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) def list_blobs( - file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl + targets: List[str], + file_path: Path, + kernel_filter: str, + receipt, + optdim_list, + mask_impl, ) -> None: with file_path.open("a") as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + _, kernels = get_fwd_blobs( + targets, kernel_filter, receipt, optdim_list, mask_impl + ) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/utils.py b/example/ck_tile/01_fmha/codegen/utils.py index e3bbb18c42..7afa4f6dd8 100644 --- a/example/ck_tile/01_fmha/codegen/utils.py +++ b/example/ck_tile/01_fmha/codegen/utils.py @@ -2,7 +2,9 @@ # Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation +import dataclasses import os.path as path +import textwrap def update_file(file_path, content): @@ -19,3 +21,51 @@ def update_file(file_path, content): return with open(file_path, "w") as file: file.write(content) + + +def indent(code: str, indent: str = " ") -> str: + return textwrap.indent(code, indent) + + +def if_(i: int) -> str: + return "if" if i == 0 else "else if" + + +def check_duplicates_and_paddings(traits, trait): + """Check + * that the traits list does not already contain a trait with the same parameters; + * that paddings are consistent: the previous kernel can be incorrectly called before the new one, + for example, f, _t_, f, t cannot be before f, _f_, f, t. 
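+ A worked illustration (pad flags listed in the order spad, skpad, dpad, dvpad):
+ f, t, f, t pads seqlen_k, so its runtime check accepts any seqlen_k;
+ f, f, f, t has no seqlen_k padding and only accepts tile-aligned lengths.
+ The generated API dispatches through an ordered if/else-if chain, so if the
+ first trait were registered before the second, its looser check would match
+ every shape the second one handles, leaving the second kernel dead code.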
+ """ + + fields = [f.name for f in dataclasses.fields(trait)] + pad_fields = [f for f in fields if "pad" in f] + non_pad_fields = [f for f in fields if "pad" not in f] + for prev_trait in traits: + if any(getattr(trait, f) != getattr(prev_trait, f) for f in non_pad_fields): + continue + if all(getattr(trait, f) == getattr(prev_trait, f) for f in pad_fields): + raise Exception(f"Duplicate found {trait}") + # Check if the previous kernel can be incorrectly used before the current one + # for example, f, _t_, f, t cannot be before f, _f_, f, t + is_prev_more_restrictive = False + is_curr_more_restrictive = False + for f in pad_fields: + prev_pad = getattr(prev_trait, f) + pad = getattr(trait, f) + if isinstance(prev_pad, str): + prev_pad = 1000000 if prev_pad == "f" else 1 + pad = 1000000 if pad == "f" else 1 + elif isinstance(prev_pad, int): + prev_pad = 1000000 if prev_pad == 0 else prev_pad + pad = 1000000 if pad == 0 else pad + else: + assert False + if prev_pad < pad: + is_prev_more_restrictive = True + elif prev_pad > pad: + is_curr_more_restrictive = True + if is_prev_more_restrictive and not is_curr_more_restrictive: + raise Exception( + f"Kernel will never be used because paddings are not ordered correctly:\n{prev_trait} supersedes\n{trait}" + ) diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 570a4bed82..7b6e4a8aad 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -453,15 +453,15 @@ struct fmha_bwd_dq_dk_dv_traits_ { }; -template +template float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args); -template +template void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); -template +template std::string fmha_bwd_dq_dk_dv_get_name_(); -template +template int fmha_bwd_dq_dk_dv_maxq_(); template @@ -474,13 +474,13 @@ struct fmha_bwd_dot_do_o_traits_ static constexpr bool kPadDv = kPadDv_; }; -template +template float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args); -template +template void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); -template +template std::string fmha_bwd_dot_do_o_get_name_(); template +template float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args); -template +template void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); -template +template std::string fmha_bwd_convert_dq_get_name_(); // This is the public API, will be generated by script diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 383be6e099..a952800806 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1159,7 +1159,7 @@ struct fmha_fwd_traits_ static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; -template +template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); template +template float fmha_fwd_pagedkv_(const ck_tile::stream_config&, fmha_fwd_pagedkv_args); template +template void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); -template +template std::string fmha_fwd_splitkv_get_name_(); template +template void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); -template +template std::string fmha_fwd_splitkv_combine_get_name_(); // this is used to pattern-match internl kernel implementation, not to instantiate kernel @@ -1322,10 +1322,10 @@ struct fmha_fwd_appendkv_traits_ static constexpr bool kIsPagedKV = 
kIsPagedKV_; }; -template +template float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args); -template +template float fmha_batch_prefill_(const ck_tile::stream_config&, fmha_batch_prefill_args); // This is the public API, will be generated by script diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index ca3cd51c57..8a663d038d 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1200,7 +1200,7 @@ fwd_result fmha_fwd_run(mode_enum mode, } }; - auto run_appendkv = [&](const ck_tile::stream_config& sc) { + auto run_appendkv = [&]([[maybe_unused]] const ck_tile::stream_config& sc) { #if CK_TILE_FMHA_FWD_APPENDKV_API if(need_append_kvcache) { diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index fce37061f6..8011c71416 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + # generate kernel instances to speed up compilation import argparse @@ -38,6 +39,7 @@ assert 0 < len(handlers) def write_blobs( + targets: List[str], output_dir: Optional[str], api_list: List[str], filters_list: List[str], @@ -54,11 +56,12 @@ def write_blobs( for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.WRITE_BLOBS] - handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl) + handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl) # list all the files that will be generated def list_blobs( + targets: List[str], output_file: Optional[str], api_list: List[str], filters_list: List[str], @@ -74,7 +77,7 @@ def list_blobs( for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.LIST_BLOBS] - handler(file_path, kernel_filter, receipt, optdim_list, mask_impl) + handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl) if __name__ == "__main__": @@ -82,6 +85,12 @@ if __name__ == "__main__": prog="generate", description="gen API for CK fmha kernel", ) + parser.add_argument( + "--targets", + default="gfx9,gfx950", + required=False, + help="list of GPU targets, separated by comma.", + ) parser.add_argument( "-d", "--direction", # we keep 'direction' option for backward compatibility @@ -142,6 +151,7 @@ if __name__ == "__main__": ) args = parser.parse_args() + targets = args.targets.split(",") api_list = args.direction.split(",") filter_list = args.filter.split(",") filter_list.extend([""] * (len(api_list) - len(filter_list))) @@ -149,6 +159,7 @@ if __name__ == "__main__": if args.list_blobs is not None: list_blobs( + targets, args.list_blobs, api_list, filter_list, @@ -158,6 +169,7 @@ if __name__ == "__main__": ) else: write_blobs( + targets, args.output_dir, api_list, filter_list, diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index fca6b8d0cd..02bc5476fa 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -94,7 +94,7 @@ run_fp8_tests() { for b in 1 2 ; do for hdim in 64 128 256 ; do - $EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8 -init=0 -b=$b -h=1 -d=$hdim -s=128 
-bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS done ; done ; done ; done } @@ -105,7 +105,7 @@ run_fp8bf16_tests() { for b in 1 2 ; do for hdim in 64 128 256 ; do - $EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS done ; done ; done ; done } @@ -114,9 +114,9 @@ run_fp8fp32_tests() { for perm in 0 1 ; do for bias in "n" "e" "a" ; do for b in 1 2 ; do - for hdim in 64 128 256 ; do + for hdim in 128 ; do - $EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS done ; done ; done ; done } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 903de0d581..90137331f6 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -3,6 +3,8 @@ #pragma once +#include "ck_tile/core/config.hpp" + #if !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN #include "ck_tile/core/numeric/integer.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index d3405c7053..3e4f6f35be 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -3,6 +3,8 @@ #pragma once +#include "ck_tile/core/config.hpp" + #if CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN #include "ck_tile/core/numeric/integer.hpp" diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 31ba053796..6990fc1496 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -313,6 +313,12 @@ CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_ } // Architecture tags +struct gfx9_t +{ +}; +struct gfx950_t +{ +}; struct gfx11_t { }; diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index a7fe6b37e1..76a1c03269 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -264,7 +264,7 @@ #endif #ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN -#if __clang_major__ >= 20 +#if __clang_major__ >= 20 && !(defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__)) #define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 1 #else #define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0 diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index e709fed23d..b17890b733 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -168,8 +168,12 @@ uint16_t float_to_bf16_rtn_asm(float f) static constexpr uint32_t FP32_NAN = 0x7fff0000; static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff; +#if defined(__GFX9__) using uint32x2_t = uint32_t __attribute__((ext_vector_type(2))); uint32x2_t check_nan; +#else + uint32_t check_nan; +#endif uint32_t tmp; asm volatile("\n \ v_cmp_u_f32 %0, %2, %2 \n \ @@ -204,8 +208,12 @@ uint16_t float_to_bf16_rta_asm(float f) const uint32_t low_nan = 0x7fff; const uint32_t hi_nan = 0x7fff0000; +#if defined(__GFX9__) using uint32x2_t = uint32_t 
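// A note on the #if/#else above, as a sketch: v_cmp_u_f32 writes one result bit per
// lane into scalar registers, so the comparison mask is 64 bits wide on wave64
// targets (gfx9) but only 32 bits wide on wave32 targets, and check_nan must match
// that width. A host-side model of the sizing rule, with a hypothetical WaveSize
// parameter (not a ck_tile name):
//
//   template <int WaveSize>
//   using lane_mask_t = std::conditional_t<WaveSize == 64, uint64_t, uint32_t>;
//
// On gfx9 the 64-bit mask is spelled as a uint32x2_t vector rather than uint64_t,
// presumably to fit the scalar-register constraint used by the inline asm.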
__attribute__((ext_vector_type(2))); uint32x2_t check_nan; +#else + uint32_t check_nan; +#endif asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n" "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n" diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index ca0383a57c..3729a0de5c 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -5,11 +5,8 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" -#if __clang_major__ >= 20 #include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp" -#else #include "ck_tile/core/arch/amd_buffer_addressing.hpp" -#endif #include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/numeric/integer.hpp" diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index d29afa2d98..1863192a1f 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -184,7 +184,7 @@ namespace impl { template CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors) { -#if defined(__gfx94__) +#if defined(__gfx94__) || defined(__gfx12__) // This API is designed to use the _pk_ series of functions constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); @@ -195,7 +195,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors) auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wuninitialized" - // __builtin_amdgcn_cvt_pk_fp8_f32() this builtin require the old value, and + // __builtin_amdgcn_cvt_pk_fp8_f32() this builtin requires the old value, and // will generate a v_mov_b32 vxxx [old] before cvt, which results in unwanted ISA // so we prepare an uninitialized variable purposely, and turn off the warning int dummy_old; @@ -209,13 +209,12 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors) uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32( in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}], in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}], - dummy_old, - false); // false -> WORD0 + x, + true); // true -> WORD1 - constexpr int32_t m0 = 0x05040100; - using vec_t = array; + using vec_t = array; - vec_t d = bit_cast(__builtin_amdgcn_perm(y, x, m0)); + vec_t d = bit_cast(y); out_dstr_tensor.get_thread_buffer().template set_as(number{}, d); }); #pragma clang diagnostic pop diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index 368a0594c5..9ac0b5ba0e 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -28,6 +28,19 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) #endif } +template +#if CK_TILE_USE_LAUNCH_BOUNDS +__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) +#endif + __global__ void kentry(Args... 
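// Sketch of the cvt_pk change above (an illustration, not the builtin's spec):
// __builtin_amdgcn_cvt_pk_fp8_f32 converts two floats into a packed pair of fp8
// bytes and writes them into one half of a 32-bit value, preserving the other
// half of the "old" operand. Passing x (values 0/1 already in WORD0) as the old
// value and selecting WORD1 for values 2/3 therefore yields all four bytes in
// one register. A host-side model of the merge:
//
//   uint32_t merge_word1(uint32_t old, uint32_t pair)
//   {
//       return (old & 0x0000ffffu) | (pair << 16); // keep WORD0, replace WORD1
//   }
//
// which is why the separate __builtin_amdgcn_perm step (and its 0x05040100
// byte selector) could be dropped.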
args) +{ +#if defined(__HIP_DEVICE_COMPILE__) + Kernel{}(args...); +#else + (..., (ignore = args, 0)); +#endif +} + // // return an anonymous functor(lambda) to be called later // the KernelImpl should be a class without non-static data member, or let's say // // the "static __device__ operator()(some_arg)" is the entry point of KernelImpl // -template +// Arch can be used to support linking multiple object files that have the same kernel compiled for +// different architectures. In this case each object file has to use a different tag (gfx9_t, +// gfx12_t etc.), so the kernel will have different symbols for each architecture. +// +template CK_TILE_HOST auto make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { - const auto kernel = kentry; + const auto kernel = []() { + if constexpr(std::is_void_v) + { + return kentry; + } + else + { + return kentry; + } + }(); return [=](const stream_config& s) { kernel<<>>(args...); }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 56865498c0..d991d5fe25 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -692,7 +692,17 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel } } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static dim3 BlockSize() + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 668fab3fd3..b5bd4c74ef 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -677,7 +677,17 @@ struct FmhaBwdDQDKDVKernel return ck_tile::make_tuple(i_block, i_nhead, i_batch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static dim3 BlockSize() + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { @@ -1171,6 +1181,21 @@ struct FmhaBwdDQDKDVKernel scale_rp_undrop, dropout); +#if defined(__gfx12__) + // Workaround for a compiler bug (SWDEV-559729): v_wmma instructions can be incorrectly + // placed in divergent branches used to store padded tensors (when some lanes are + // inactive due to padding). Inline asm with dummy dependencies on VGPRs of the tensors + // prevents the compiler from doing this. + if constexpr(kPadHeadDimQ > 0) + { + impl::insert_dummy_dep(dk_acc_tile.get_thread_buffer()); + } + if constexpr(kPadHeadDimV > 0) + { + impl::insert_dummy_dep(dv_acc_tile.get_thread_buffer()); + } +#endif + KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile, nullptr); VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile, nullptr); } @@ -1241,7 +1266,7 @@ struct FmhaBwdOGradDotOKernel return n.empty() ? n : std::string("p") + n; }(); return _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" + + "_b" + _TS_(kM0) + "_" + (kIsGroupMode ? "group" : "batch") + "_" + + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? 
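// Sketch of the intended use of the new Arch tag in make_kernel above, assuming
// it is supplied as an explicit template argument (kernel and argument names
// below are illustrative, and the exact parameter order is elided in this patch):
//
//   // translation unit compiled for gfx9:
//   auto run9  = ck_tile::make_kernel<ck_tile::gfx9_t>(Kern{}, grids, blocks, lds, kargs);
//   // translation unit compiled for gfx12:
//   auto run12 = ck_tile::make_kernel<ck_tile::gfx12_t>(Kern{}, grids, blocks, lds, kargs);
//
// Each tag instantiates a differently mangled kentry symbol, so object files
// holding the same kernel compiled for different architectures can be linked
// into one binary; leaving Arch unspecified (void) keeps the original,
// arch-less symbol and behavior.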
"_npad" : "_" + pn); #undef _SS_ #undef _TS_ @@ -1371,7 +1396,7 @@ struct FmhaBwdOGradDotOKernel return ck_tile::make_tuple(i_block, i_nhead, i_batch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } @@ -1678,7 +1703,7 @@ struct FmhaBwdConvertQGradKernel return ck_tile::make_tuple(i_block, i_nhead, i_batch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index a82d121d62..02296513d8 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -42,7 +42,7 @@ struct FmhaFwdAppendKVKernel template <> struct t2s { static constexpr const char * name = "bf8"; }; // clang-format on - __host__ static std::string GetName() + CK_TILE_HOST static std::string GetName() { // sync with generate.py // clang-format off @@ -143,41 +143,41 @@ struct FmhaFwdAppendKVKernel { }; - __host__ static constexpr Kargs MakeKargs(void* q_ptr, - void* k_ptr, - const void* knew_ptr, - void* v_ptr, - const void* vnew_ptr, - ck_tile::index_t seqlen_q, - const void* seqlen_k_ptr, - ck_tile::index_t seqlen_knew, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - const void* rotary_cos_ptr, - const void* rotary_sin_ptr, - ck_tile::index_t rotary_dim, - bool has_mask, - const void* block_table_ptr, - ck_tile::index_t batch_stride_block_table, - ck_tile::index_t page_block_size, - const void* cache_batch_idx, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_knew, - ck_tile::index_t stride_v, - ck_tile::index_t stride_vnew, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_knew, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_vnew, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_knew, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_vnew) + CK_TILE_HOST static constexpr Kargs MakeKargs(void* q_ptr, + void* k_ptr, + const void* knew_ptr, + void* v_ptr, + const void* vnew_ptr, + ck_tile::index_t seqlen_q, + const void* seqlen_k_ptr, + ck_tile::index_t seqlen_knew, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + const void* rotary_cos_ptr, + const void* rotary_sin_ptr, + ck_tile::index_t rotary_dim, + bool has_mask, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, + const void* cache_batch_idx, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_knew, + ck_tile::index_t stride_v, + ck_tile::index_t stride_vnew, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_knew, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_vnew, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_knew, + ck_tile::index_t batch_stride_v, + ck_tile::index_t 
batch_stride_vnew) { Kargs kargs{ {q_ptr, @@ -255,7 +255,7 @@ struct FmhaFwdAppendKVKernel return ck_tile::make_tuple(i_tile, i_nhead, i_batch); } - __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); } CK_TILE_DEVICE void operator()(Kargs kargs) const { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index f539c9d7e9..fe7c8d48c8 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1079,7 +1079,17 @@ struct FmhaFwdKernel } } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static dim3 BlockSize() + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index 62ac70db92..a2e6f08361 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -865,7 +865,17 @@ struct FmhaFwdPagedKVKernel } } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static dim3 BlockSize() + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index a6fc0f1471..99a301f620 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -37,7 +37,7 @@ struct FmhaFwdSplitKVCombineKernel template <> struct t2s { static constexpr const char * name = "bf8"; }; // clang-format on - __host__ static std::string GetName() + CK_TILE_HOST static std::string GetName() { // sync with generate.py // clang-format off @@ -127,7 +127,7 @@ struct FmhaFwdSplitKVCombineKernel using Kargs = std::conditional_t; template - __host__ static constexpr std::enable_if_t + CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* lse_acc_ptr, const void* o_acc_ptr, void* lse_ptr, @@ -185,7 +185,7 @@ struct FmhaFwdSplitKVCombineKernel } template - __host__ static constexpr std::enable_if_t + CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* lse_acc_ptr, const void* o_acc_ptr, void* lse_ptr, @@ -240,8 +240,10 @@ struct FmhaFwdSplitKVCombineKernel ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v) { + // Recalculate kM0 = get_warp_size() / NThreads on host + const index_t m0 = (is_wave32() ? 
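// Sketch of the wave32 launch math behind the new BlockSize() bodies and this
// grid recalculation: the warp count is fixed at compile time, but the lane
// count per warp is only known on the host at launch time, so
//
//   threads_per_block = kNumWarps * (is_wave32() ? 32 : 64);
//
// e.g. a 4-warp kernel launches 256 threads on gfx9 (wave64) but 128 threads on
// a wave32 target, which is exactly what returning kBlockSize / 2 does, and why
// warp-size-derived tile extents such as kM0 are recomputed here the same way.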
32 : 64) / FmhaPipeline::Problem::NThreads; // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, m0) * ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1), nhead, batch_size); @@ -266,7 +268,17 @@ struct FmhaFwdSplitKVCombineKernel return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); } - __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static dim3 BlockSize() + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { @@ -344,7 +356,7 @@ struct FmhaFwdSplitKVCombineKernel const auto lse_acc_dram_naive = make_naive_tensor_view( lse_acc_ptr, make_tuple(kargs.num_splits, kargs.seqlen_q), - make_tuple(kargs.split_stride_lse_acc, 1), + make_tuple(kargs.split_stride_lse_acc, number<1>{}), number{}, number<1>{}); @@ -358,11 +370,11 @@ struct FmhaFwdSplitKVCombineKernel const auto o_acc_dram_naive = make_naive_tensor_view( o_acc_ptr, make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1), + make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, number<1>{}), number{}, number<1>{}); - // read 4 * (kM0, kN1) o_acc tiles simultaneously by 4 warps + // read kNumWarps * (kM0, kN1) o_acc tiles simultaneously by kNumWarps warps const auto o_acc_dram_view = pad_tensor_view( o_acc_dram_naive, make_tuple( @@ -469,7 +481,7 @@ struct FmhaFwdSplitKVCombineKernel const auto o_dram_naive = make_naive_tensor_view( o_ptr, make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.row_stride_o, 1), + make_tuple(kargs.row_stride_o, number<1>{}), number{}, number<1>{}); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 80de65ead4..a6e44c7293 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -70,7 +70,7 @@ struct FmhaFwdSplitKVKernel template <> struct t2s { static constexpr const char * name = "bf8"; }; // clang-format on - __host__ static std::string GetName() + CK_TILE_HOST static std::string GetName() { // sync with generate.py // clang-format off @@ -279,7 +279,7 @@ struct FmhaFwdSplitKVKernel }; template - __host__ static constexpr std::enable_if_t + CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, @@ -409,7 +409,7 @@ struct FmhaFwdSplitKVKernel } template - __host__ static constexpr std::enable_if_t + CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, @@ -574,7 +574,17 @@ struct FmhaFwdSplitKVKernel } } - __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static dim3 BlockSize() + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 5eac387a66..d9e19a0c7e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -7,7 +7,6 @@ #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" @@ -683,26 +682,26 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution() { - using AccDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kKPerBlock = Problem::kQKHeaddim; - constexpr index_t K1 = 16 / sizeof(AccDataType); - constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t K2 = GetAlignmentPostQGradAcc(); + constexpr index_t K1 = min(kKPerBlock / K2, get_warp_size()); + constexpr index_t K0 = kKPerBlock / (K1 * K2); - constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M2 = get_warp_size() / K1; constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M0 = kMPerBlock / (M1 * M2); constexpr auto dstr = make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence, sequence>, - tuple, sequence<2, 3>>, - tuple, sequence<2, 0>>, - sequence<1, 2, 3>, - sequence<0, 0, 1>>{}); + tile_distribution_encoding< + sequence<>, + tuple, sequence, sequence>, + tuple, sequence<2, 3>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 3, 3>, + sequence<0, 0, 0, 2>>{}); static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == kMPerBlock * kKPerBlock); return dstr; @@ -711,27 +710,25 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDramTileDistribution() { - using AccDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kKPerBlock = Problem::kQKHeaddim; - constexpr index_t K1 = 16 / sizeof(AccDataType); - constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t K2 = GetAlignmentPostQGrad(); + constexpr index_t K1 = min(kKPerBlock / K2, get_warp_size()); + constexpr index_t K0 = kKPerBlock / (K1 * K2); - constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M2 = get_warp_size() / K1; constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M0 = kMPerBlock / (M1 * M2); constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, sequence>, tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == kMPerBlock * kKPerBlock); return dstr; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp index f7ee88f906..13ef642b1b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp @@ -31,59 +31,33 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy typename 
Problem::BlockFmhaShape::Gemm0WarpTile>>; constexpr auto warp_gemm = []() { - constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); - - if constexpr(std::is_same_v && - std::is_same_v && + if constexpr(get_warp_size() == 64 && + std::is_same_v && + std::is_same_v && std::is_same_v) { - static_assert(WarpGemmM == 16); + static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32); + static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32); + static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32); + // TODO: hard coded here. Otherwise, it produces incorrect results + constexpr index_t swizzle_factor = 4; + return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution< + swizzle_factor>{}; + } + else + { + constexpr bool SwizzleA = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32; return WarpGemmDispatcher{}), Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), - true>{}; + true, // TransposeC + SwizzleA>{}; } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); - - if constexpr(WarpGemmM == 32) - return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; - else if constexpr(WarpGemmM == 16) - return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; - else // WarpGemmM == 4 - return WarpGemmMfmaF16F16F32M4N64K16{}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); - - if constexpr(WarpGemmM == 32) - return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; - else if constexpr(WarpGemmM == 16) - return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; - else // WarpGemmM == 4 - return WarpGemmMfmaBf16Bf16F32M4N64K16{}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - static_assert(WarpGemmM == 32); - - // TODO: hard coded here. 
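// Sketch of the consolidation applied in this hunk (and repeated in the two
// pipeline policies further below): the hand-written per-dtype WarpGemmMfma*
// selection collapses into a single dispatcher instantiation keyed on the
// warp-tile shape, with fp8 on wave64 kept as the one remaining special case.
// Illustrative shape of the call (parameter list abridged, not the dispatcher's
// real signature):
//
//   using WarpGemm = WarpGemmDispatcher<ADataType, BDataType, AccDataType,
//                                       WarpTileM, WarpTileN, WarpTileK,
//                                       /*TransposeC=*/true,
//                                       /*SwizzleA=*/(WarpTileM == 32)>;
//
// so a new dtype/tile combination is added in the dispatcher instead of in
// every pipeline policy.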
Otherwise, it may incorrect result - constexpr index_t swizzle_factor = 4; - return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution< - swizzle_factor>{}; - } // TODO - bf8_t }(); using BlockGemmPolicy = diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 7b30f36fd8..f2b524fa3d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -273,16 +273,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); } - auto o_acc_4_dist = Policy::template MakeOacc4DramTileDistribution(); - auto o_acc_4_dram_window = + // First each warp processes its own part of splits + + auto o_acc_dist = Policy::template MakeOaccDramTileDistribution(); + auto o_acc_dram_window = make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(), o_acc_dram_block_window_tmp.get_window_lengths(), o_acc_dram_block_window_tmp.get_window_origin(), - o_acc_4_dist); + o_acc_dist); - // shape=[4 * KM0, kN1] - auto o_acc_4 = make_static_distributed_tensor(o_acc_4_dist); - clear_tile(o_acc_4); + // shape=[kNumWarps * KM0, kN1] + auto o_acc = make_static_distributed_tensor(o_acc_dist); + clear_tile(o_acc); const index_t padded_num_splits = integer_divide_ceil(num_splits, kNumWarps) * kNumWarps; @@ -291,73 +293,73 @@ struct BlockFmhaFwdSplitKVCombinePipeline // each warp handles a [KM0, kN1] tile for(index_t split_start = 0; split_start < padded_num_splits; split_start += kNumWarps) { - auto o_tile = load_tile(o_acc_4_dram_window); + auto o_tile = load_tile(o_acc_dram_window); const index_t i_split = split_start + get_warp_id(); const index_t row_start = kM0 * get_warp_id(); - { - constexpr auto spans = decltype(o_acc_4)::get_distributed_spans(); - sweep_tile_span(spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - const auto x_indices = get_x_indices_from_distributed_indices( - o_acc_4.get_tile_distribution(), i_j_idx); - - const auto row = x_indices.at(number<0>{}); - - const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split); - o_acc_4(i_j_idx) += lse_scale * o_tile(i_j_idx); - }); - }); - } - - move_tile_window(o_acc_4_dram_window, {kNumWarps * kM0, 0}); - } - - // 4 o_acc tiles in LDS. 
shape=[4 * kM0, kN1] - OaccDataType* o_acc_4_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeLSEacc())); - - { - auto o_acc_4_lds_window = [&]() { - auto desc = Policy::template MakeOacc4LdsBlockDescriptor(); - auto view = make_tensor_view(o_acc_4_lds_ptr, desc); - return make_tile_window(view, desc.get_lengths(), {0, 0}); - }(); - store_tile(o_acc_4_lds_window, o_acc_4); - } - - auto o_acc_dist = Policy::template MakeOaccDramTileDistribution(); - - auto o_acc_4_lds_window = [&]() { - auto desc = Policy::template MakeOacc4LdsBlockDescriptor(); - auto view = make_tensor_view(o_acc_4_lds_ptr, desc); - return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_dist); - }(); - - auto o_acc = make_static_distributed_tensor(o_acc_dist); - clear_tile(o_acc); - - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - static_for<0, kNumWarps, 1>{}([&](auto) { - auto o_acc_in = load_tile(o_acc_4_lds_window); - - { constexpr auto spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) += o_acc_in(i_j_idx); + const auto x_indices = get_x_indices_from_distributed_indices( + o_acc.get_tile_distribution(), i_j_idx); + + const auto row = x_indices.at(number<0>{}); + + const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split); + o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx); }); }); } - move_tile_window(o_acc_4_lds_window, {kM0, 0}); + move_tile_window(o_acc_dram_window, {kNumWarps * kM0, 0}); + } + + // Then each warp combines the partial o_acc results into one + + // kNumWarps o_acc tiles in LDS. shape=[kNumWarps * kM0, kN1] + OaccDataType* o_acc_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeLSEacc())); + + { + auto o_acc_lds_store_window = [&]() { + auto desc = Policy::template MakeOaccLdsBlockDescriptor(); + auto view = make_tensor_view(o_acc_lds_ptr, desc); + return make_tile_window(view, desc.get_lengths(), {0, 0}); + }(); + store_tile(o_acc_lds_store_window, o_acc); + } + + auto o_acc_result_dist = Policy::template MakeOaccResultDramTileDistribution(); + + auto o_acc_lds_load_window = [&]() { + auto desc = Policy::template MakeOaccLdsBlockDescriptor(); + auto view = make_tensor_view(o_acc_lds_ptr, desc); + return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_result_dist); + }(); + + auto o_acc_result = make_static_distributed_tensor(o_acc_result_dist); + clear_tile(o_acc_result); + + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + static_for<0, kNumWarps, 1>{}([&](auto) { + auto o_acc_in = load_tile(o_acc_lds_load_window); + + { + constexpr auto spans = decltype(o_acc_result)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc_result(i_j_idx) += o_acc_in(i_j_idx); + }); + }); + } + + move_tile_window(o_acc_lds_load_window, {kM0, 0}); }); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - - return o_acc; + return tile_elementwise_in(o_acc_element_func, o_acc_result); } template ; - constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNumWarps = Problem::kNumWarps; constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kNPerBlock = Problem::kN1; - constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr 
index_t M2 = min(kMPerBlock / M1, get_warp_size()); constexpr index_t N0 = get_warp_size() / M2; constexpr index_t N1 = kNPerBlock / N0; @@ -78,16 +78,16 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy } template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc4() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc() { return sizeof(typename Problem::OaccDataType) * - MakeOacc4LdsBlockDescriptor().get_element_space_size(); + MakeOaccLdsBlockDescriptor().get_element_space_size(); } template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return GetSmemSizeLSEacc() + GetSmemSizeOacc4(); + return GetSmemSizeLSEacc() + GetSmemSizeOacc(); } // shape=[kMaxSplits, kM0] @@ -129,8 +129,8 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy { using LSEDataType = remove_cvref_t; - constexpr index_t kMPerBlock = Problem::kM0; - constexpr index_t kNPerBlock = Problem::kMaxSplits; + constexpr index_t kMPerBlock = Problem::kMaxSplits; + constexpr index_t kNPerBlock = Problem::kM0; constexpr index_t NPack = GetVectorSizeForTile(); @@ -142,8 +142,9 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor( lse_acc_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))), + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -156,8 +157,8 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy { using LSEDataType = remove_cvref_t; - constexpr index_t kMPerBlock = Problem::kM0; - constexpr index_t kNPerBlock = Problem::kMaxSplits; + constexpr index_t kMPerBlock = Problem::kMaxSplits; + constexpr index_t kNPerBlock = Problem::kM0; constexpr index_t NPack = GetVectorSizeForTile(); @@ -169,21 +170,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor( lse_acc_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))), + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<1>{}, sequence<0>{})); return lse_acc_t_lds_block_desc; } - // 3d + padding, shape=[4 * kM0, kN1] + // 3d + padding, shape=[kNumWarps * kM0, kN1] template - CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4LdsBlockDescriptor() + CK_TILE_HOST_DEVICE static constexpr auto MakeOaccLdsBlockDescriptor() { using LSEDataType = remove_cvref_t; - constexpr index_t kMPerBlock = 4 * Problem::kM0; + constexpr index_t kNumWarps = Problem::kNumWarps; + constexpr index_t kMPerBlock = kNumWarps * Problem::kM0; constexpr index_t kNPerBlock = Problem::kN1; constexpr index_t NPack = GetVectorSizeForTile(); @@ -191,17 +194,17 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy constexpr auto o_acc_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number<(kMPerBlock + 1) * NPack>{}, number{}, number<1>{}), - number<8>{}, + number{}, number<1>{}); - constexpr auto o_acc_t_lds_block_desc = transform_tensor_descriptor( + constexpr auto o_acc_lds_block_desc = transform_tensor_descriptor( o_acc_lds_block_desc_0, make_tuple(make_pass_through_transform(kMPerBlock), 
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))), make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + make_tuple(sequence<0>{}, sequence<1>{})); - return o_acc_t_lds_block_desc; + return o_acc_lds_block_desc; } // shape=[kM0, kMaxSplits] @@ -235,12 +238,13 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy sequence<2, 1>>{}); } - // similar to MakeOaccDramTileDistribution(), but duplicate same 1-warp encoding 4 times on M - // direction + // similar to MakeOaccResultDramTileDistribution(), but duplicate same 1-warp encoding kNumWarps + // times on M direction template - CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4DramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution() { - constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (4 * kM0) + constexpr index_t kNumWarps = Problem::kNumWarps; + constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (kNumWarps * kM0) constexpr index_t kNPerBlock = Problem::kN1; static_assert(get_warp_size() <= kMPerBlock * kNPerBlock); @@ -252,7 +256,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, sequence>, tuple, sequence<1, 2>>, tuple, sequence<3, 0>>, sequence<1, 2>, @@ -260,14 +264,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeOaccResultDramTileDistribution() { - constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNumWarps = Problem::kNumWarps; constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kNPerBlock = Problem::kN1; - static_assert(kBlockSize <= kMPerBlock * kNPerBlock); + static_assert(kNumWarps * get_warp_size() <= kMPerBlock * kNPerBlock); - constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M1 = kNumWarps; constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size()); constexpr index_t N0 = get_warp_size() / M2; constexpr index_t N1 = kNPerBlock / N0; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 7775848195..cc0851efb3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
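// Sketch of the LDS layout built by the descriptors above (sizes illustrative):
// the leading stride is (kMPerBlock + 1) * NPack rather than kMPerBlock * NPack,
// i.e. every slice is padded by one NPack-wide row, so the element offset is
//
//   offset(k0, m, p) = k0 * (kMPerBlock + 1) * NPack + m * NPack + p;
//
// The extra row staggers same-column addresses across LDS banks between
// consecutive slices, a common bank-conflict-avoidance layout; it is also why
// get_element_space_size() is larger than the plain kMPerBlock * kNPerBlock
// element count.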
#pragma once
@@ -204,6 +204,7 @@ struct BlockFmhaSplitKVCombinePipelineProblem
using BaseType::kM0;
using BaseType::kN1;
+ using BaseType::NThreads;
static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
@@ -216,7 +217,7 @@ struct BlockFmhaSplitKVCombinePipelineProblem
static constexpr index_t kMaxSplits = Traits::kMaxSplits;
static_assert(8 <= kMaxSplits);
- static constexpr index_t kNumWarps = 4; // always use 4 warps for each workgroup
+ static constexpr index_t kNumWarps = 4;
static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
index f8d9973918..9e9cce5400 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
@@ -58,17 +58,6 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
- using BlockGemm0 = remove_cvref_t())>;
- static constexpr auto WarpGemmConfig =
- BlockGemm0::Policy::template GetWarpGemmMWarpNWarp();
- using WarpGemm0 = remove_cvref_t())>;
- static constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
- static constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
- static constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
- static constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
- static constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
- static constexpr int NumMfmaInsts =
- (kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) * (kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
@@ -298,7 +287,18 @@ struct BlockFmhaPipelineQRKSVS
// Use compile-time conditional for group barrier sequence
// (No runtime lambda selection)
auto schedule_gemm0 = [] {
- if constexpr(kQKHeaddim == 256)
+ using BlockGemm0 = remove_cvref_t;
+ constexpr auto WarpGemmConfig =
+ BlockGemm0::Policy::template GetWarpGemmMWarpNWarp();
+ using WarpGemm0 = remove_cvref_t())>;
+ constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
+ constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
+ constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
+ constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
+ constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
+ constexpr index_t NumMfmaInsts = (kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) *
+ (kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
+ if constexpr(get_warp_size() == 64 && kQKHeaddim == 256)
{
static_assert(NumMfmaInsts % 8 == 0);
static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) {
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp
index 6d414ee851..575f9f106e 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp
@@ -6,7 +6,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
-#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
index 050eb48384..014467fe8a 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
@@ -263,59 +263,33 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
- constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
-
- if constexpr(std::is_same_v &&
- std::is_same_v &&
+ if constexpr(get_warp_size() == 64 &&
+ std::is_same_v &&
+ std::is_same_v &&
std::is_same_v)
{
- static_assert(WarpGemmM == 16);
+ static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
+ static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
+ static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
+ // TODO: hard coded here. Otherwise, it produces incorrect results
+ constexpr index_t swizzle_factor = 4;
+ return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
+ swizzle_factor>{};
+ }
+ else
+ {
+ constexpr bool SwizzleA =
+ Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
return WarpGemmDispatcher{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
- true>{};
+ true, // TransposeC
+ SwizzleA>{};
}
- else if constexpr(std::is_same_v &&
- std::is_same_v &&
- std::is_same_v)
- {
- static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
- if constexpr(WarpGemmM == 32)
- return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
- else if constexpr(WarpGemmM == 16)
- return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
- else // WarpGemmM == 4
- return WarpGemmMfmaF16F16F32M4N64K16{};
- }
- else if constexpr(std::is_same_v &&
- std::is_same_v &&
- std::is_same_v)
- {
- static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
- if constexpr(WarpGemmM == 32)
- return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
- else if constexpr(WarpGemmM == 16)
- return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
- else // WarpGemmM == 4
- return WarpGemmMfmaBf16Bf16F32M4N64K16{};
- }
- else if constexpr(std::is_same_v &&
- std::is_same_v &&
- std::is_same_v)
- {
- static_assert(WarpGemmM == 32);
-
- // TODO: hard coded here. Otherwise, it may incorrect result
- constexpr index_t swizzle_factor = 4;
- return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
- swizzle_factor>{};
- } // TODO - bf8_t
}();
using BlockGemmPolicy =
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
index 9dba3c85d5..7c794a3646 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
@@ -72,59 +72,33 @@ struct BlockFmhaPipelineQXCustomPolicy
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
- constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
-
- if constexpr(std::is_same_v &&
- std::is_same_v &&
+ if constexpr(get_warp_size() == 64 &&
+ std::is_same_v &&
+ std::is_same_v &&
std::is_same_v)
{
- static_assert(WarpGemmM == 16);
+ static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
+ static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
+ static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
+ // TODO: hard coded here. Otherwise, it produces incorrect results
+ constexpr index_t swizzle_factor = 4;
+ return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
+ swizzle_factor>{};
+ }
+ else
+ {
+ constexpr bool SwizzleA =
+ Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
return WarpGemmDispatcher{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
- true>{};
+ true, // TransposeC
+ SwizzleA>{};
}
- else if constexpr(std::is_same_v &&
- std::is_same_v &&
- std::is_same_v)
- {
- static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
- if constexpr(WarpGemmM == 32)
- return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
- else if constexpr(WarpGemmM == 16)
- return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
- else // WarpGemmM == 4
- return WarpGemmMfmaF16F16F32M4N64K16{};
- }
- else if constexpr(std::is_same_v &&
- std::is_same_v &&
- std::is_same_v)
- {
- static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
- if constexpr(WarpGemmM == 32)
- return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
- else if constexpr(WarpGemmM == 16)
- return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
- else // WarpGemmM == 4
- return WarpGemmMfmaBf16Bf16F32M4N64K16{};
- }
- else if constexpr(std::is_same_v &&
- std::is_same_v &&
- std::is_same_v)
- {
- static_assert(WarpGemmM == 32);
-
- // TODO: hard coded here. Otherwise, it may incorrect result
- constexpr index_t swizzle_factor = 4;
- return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
- swizzle_factor>{};
- } // TODO - bf8_t
}();
using BlockGemmPolicy =
@@ -238,7 +212,7 @@ struct BlockFmhaPipelineQXCustomPolicy
BlockGemmProblem,
@@ -246,59 +220,33 @@ struct BlockFmhaPipelineQXCustomPolicy
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
- constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
-
- if constexpr(std::is_same_v &&
- std::is_same_v &&
+ if constexpr(get_warp_size() == 64 &&
+ std::is_same_v &&
+ std::is_same_v &&
std::is_same_v)
{
- static_assert(WarpGemmM == 16);
+ static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
+ static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
+ static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
+ // TODO: hard coded here. Otherwise, it produces incorrect results
+ constexpr index_t swizzle_factor = 4;
+ return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
+ swizzle_factor>{};
+ }
+ else
+ {
+ constexpr bool SwizzleA =
+ Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
return WarpGemmDispatcher{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
- true>{};
+ true, // TransposeC
+ SwizzleA>{};
}
- else if constexpr(std::is_same_v &&
- std::is_same_v &&
- std::is_same_v)
- {
- static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
- if constexpr(WarpGemmM == 32)
- return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
- else if constexpr(WarpGemmM == 16)
- return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
- else // WarpGemmM == 4
- return WarpGemmMfmaF16F16F32M4N64K16{};
- }
- else if constexpr(std::is_same_v &&
- std::is_same_v &&
- std::is_same_v)
- {
- static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
- if constexpr(WarpGemmM == 32)
- return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
- else if constexpr(WarpGemmM == 16)
- return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
- else // WarpGemmM == 4
- return WarpGemmMfmaBf16Bf16F32M4N64K16{};
- }
- else if constexpr(std::is_same_v &&
- std::is_same_v &&
- std::is_same_v)
- {
- static_assert(WarpGemmM == 32);
-
- // TODO: hard coded here. Otherwise, it may incorrect result
- constexpr index_t swizzle_factor = 4;
- return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
- swizzle_factor>{};
- } // TODO - bf8_t
}();
using BlockGemmPolicy =
@@ -481,7 +429,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy();
using WG = remove_cvref_t())>;
- return WG::WarpGemmAttribute::Impl::kCM1PerLane;
+ constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::ODataType);
+ return min(MaxVectorSize, WG::WarpGemmAttribute::Impl::kCM1PerLane);
}
template
@@ -1019,15 +968,16 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy>;
auto warp_gemm = [&]() {
- if constexpr(std::is_same_v &&
+ if constexpr(get_warp_size() == 64 &&
+ std::is_same_v &&
std::is_same_v &&
std::is_same_v)
{
- return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
- // return
- // WarpGemmImpl>>{};
+ static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 32);
+ static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32);
+ static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32);
+
+ return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{};
}
else
{
diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
index d16651da93..f336fc7470 100644
--- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
+++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
@@ -29,59 +29,40 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single;
- if constexpr(std::is_same_v &&
- std::is_same_v &&
+ if constexpr(((std::is_same_v &&
+ std::is_same_v) ||
+ (std::is_same_v &&
+ std::is_same_v)) &&
std::is_same_v)
{
-#if 0
- constexpr index_t kBlockSize = Problem::kBlockSize;
-
- constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
- constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
- constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
-
- static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
-
- constexpr index_t NumWarp = kBlockSize / get_warp_size();
-
- if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
- kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
+ if constexpr(get_warp_size() == 64)
{
- return make_tuple(WarpGemmMfmaF16F16F32M32N32K16<>{}, 2, 2);
+ using WG = WarpGemmDispatcher;
+ return make_tuple(WG{}, 4, 1);
}
else
{
- return make_tuple(WarpGemmMfmaF16F16F32M32N32K16<>{}, 2, 2);
+ using WG = WarpGemmDispatcher;
+ return make_tuple(WG{}, 4, 1);
}
-#else
- using WG = WarpGemmDispatcher;
- return make_tuple(WG{}, 4, 1);
-#endif
- }
- else if constexpr(std::is_same_v &&
- std::is_same_v &&
- std::is_same_v)
- {
- using WG = WarpGemmDispatcher;
- return make_tuple(WG{}, 4, 1);
- }
else
{
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
index 7ae624cafc..e774e2505f 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
@@ -384,9 +384,9 @@
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
template
-using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
+using WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution =
WarpGemmImpl,
+ WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8, 2,
swizzle_factor>>;
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp
index 0f021c62f2..90f6204ff3 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp
@@ -50,6 +50,19 @@ struct CWarpDstrEncodingTrait
typename Impl::kCYs2RHsMinor>;
};
+template
+struct CTransposedWarpDstrEncodingTrait
+{
+ using type = tile_distribution_encoding<
+ sequence<>,
+ tuple,
+ sequence>,
+ tuple,
+ tuple,
+ typename Impl::kCTYs2RHsMajor,
+ typename Impl::kCTYs2RHsMinor>;
+};
+
template
struct WarpGemmAttributeWmma
{
@@ -75,9 +88,11 @@ struct WarpGemmAttributeWmma
using AWarpDstrEncoding = typename AWarpDstrEncodingTrait::type;
using BWarpDstrEncoding = typename BWarpDstrEncodingTrait::type;
- // kCM0PerLane = 4, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16 for 16 bit input
- // kCM0PerLane = 2, kCMLane = 2, kCM1PerLane = 4, kCNLane = 16 for 8 bit input
- using CWarpDstrEncoding = typename CWarpDstrEncodingTrait::type;
+ // kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16
+ using CWarpDstrEncoding =
+ std::conditional_t::type,
+ typename CWarpDstrEncodingTrait::type>;
// c_vec += a_vec * b_vec
template
CK_TILE_DEVICE void operator()(CVecType& c_vec,
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp
index 88fde40067..751ada07af 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp
@@ -60,6 +60,11 @@ struct WarpGemmAttributeWmmaImpl
using kCYs2RHsMajor = typename Traits::kCYs2RHsMajor;
using kCYs2RHsMinor = typename Traits::kCYs2RHsMinor;
+ using kCTPs2RHssMajor = typename Traits::kCTPs2RHssMajor;
+ using kCTPs2RHssMinor = typename Traits::kCTPs2RHssMinor;
+ using kCTYs2RHsMajor = typename Traits::kCTYs2RHsMajor;
+ using kCTYs2RHsMinor = typename Traits::kCTYs2RHsMinor;
+
// c_vec += a_vec * b_vec
template
CK_TILE_DEVICE void operator()(CVecType& c_vec,
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp
index 86bae7655b..ed5f0eb0a6 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp
@@ -46,6 +46,11 @@ struct WmmaTraitsBase
using kCPs2RHssMinor = sequence<1, 0>;
using kCYs2RHsMajor = sequence<1, 1>;
using kCYs2RHsMinor = sequence<0, 2>;
+
+ using kCTPs2RHssMajor = sequence<2, 1>;
+ using kCTPs2RHssMinor = sequence<1, 0>;
+ using kCTYs2RHsMajor = sequence<2, 2>;
+ using kCTYs2RHsMinor = sequence<0, 2>;
};
// GFX12 specialization
@@ -88,5 +93,10 @@ struct WmmaTraitsBase
using kCPs2RHssMinor = sequence<1, 0>;
using kCYs2RHsMajor = sequence<1, 1>;
using kCYs2RHsMinor = sequence<0, 2>;
+
+ using kCTPs2RHssMajor = sequence<2, 1>;
+ using kCTPs2RHssMinor = sequence<1, 0>;
+ using kCTYs2RHsMajor = sequence<2, 2>;
+ using kCTYs2RHsMinor = sequence<0, 2>;
};
} // namespace ck_tile
diff --git a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp
index cf477f7928..9037ccea6c 100644
--- a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp
@@ -1,8 +1,13 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
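Note on the warp_gemm_attribute_wmma changes above: a transposed-C distribution encoding (CTransposedWarpDstrEncodingTrait, backed by the new kCT* trait sequences) is added next to the existing one, and CWarpDstrEncoding now picks between the two with std::conditional_t. The kCT* sequences differ from the kC* ones only in their major axis indices (1 becomes 2), i.e. the thread and Y mappings attach to the other hidden dimension. The exact predicate of the conditional is not visible in this hunk; the sketch below shows only the selection pattern, with EncodingPlain, EncodingTransposed, and the TransposeC flag as stand-ins.

#include <type_traits>

struct EncodingPlain {};      // stands in for CWarpDstrEncodingTrait<Impl>::type
struct EncodingTransposed {}; // stands in for CTransposedWarpDstrEncodingTrait<Impl>::type

// Compile-time selection of the C-tile encoding, mirroring the
// std::conditional_t introduced in warp_gemm_attribute_wmma.hpp.
template <bool TransposeC>
using CWarpDstrEncodingSketch =
    std::conditional_t<TransposeC, EncodingTransposed, EncodingPlain>;

static_assert(std::is_same_v<CWarpDstrEncodingSketch<true>, EncodingTransposed>);
static_assert(std::is_same_v<CWarpDstrEncodingSketch<false>, EncodingPlain>);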
+
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp"
+#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp"
+#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp"
namespace ck_tile {
diff --git a/library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt
index 99ed93801d..eba234cf3f 100644
--- a/library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt
@@ -40,7 +40,7 @@ execute_process(
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
- message( FATAL_ERROR "CK Tile MHA FAILED to genrate a list of kernels via Python.")
+ message( FATAL_ERROR "CK Tile MHA FAILED to generate a list of kernels via Python.")
else()
file(STRINGS ${FMHA_CPP_FOLDER}/blob_list.txt FMHA_GEN_BLOBS)
endif()
@@ -74,4 +74,3 @@ add_instance_library(device_mha_instance ${device_files})
if (TARGET device_mha_instance)
add_dependencies(device_mha_instance generate_cpp_files)
endif()
-
diff --git a/test/ck_tile/fmha/CMakeLists.txt b/test/ck_tile/fmha/CMakeLists.txt
index e769a79c08..6592fe4a9a 100644
--- a/test/ck_tile/fmha/CMakeLists.txt
+++ b/test/ck_tile/fmha/CMakeLists.txt
@@ -1,5 +1,5 @@
# Keep in sync with example/ck_tile/01_fmha/CMakeLists.txt
-if(NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9")
+if(NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx12")
return()
endif()
diff --git a/test/ck_tile/fmha/test_fmha_bwd.cpp b/test/ck_tile/fmha/test_fmha_bwd.cpp
index 3eea02f888..5252da8720 100644
--- a/test/ck_tile/fmha/test_fmha_bwd.cpp
+++ b/test/ck_tile/fmha/test_fmha_bwd.cpp
@@ -137,6 +137,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
Values(std::tuple{24, 48},
std::tuple{48, 48},
std::tuple{72, 72},
+ std::tuple{40, 88},
std::tuple{96, 96},
std::tuple{120, 160},
std::tuple{256, 108},
diff --git a/test/ck_tile/fmha/test_fmha_fwd.cpp b/test/ck_tile/fmha/test_fmha_fwd.cpp
index 6e4b547465..c47d039bb7 100644
--- a/test/ck_tile/fmha/test_fmha_fwd.cpp
+++ b/test/ck_tile/fmha/test_fmha_fwd.cpp
@@ -38,7 +38,7 @@ struct TestConfigs
static constexpr auto AppendKVHDimValues = std::array{
std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
- static constexpr auto IsVRowmajorValues = std::array{false, true};
+ static constexpr auto IsVRowmajorValues = std::array{true};
static constexpr bool squant = false;
static constexpr bool def_lse = true;
static constexpr bool def_is_v_rowmajor = true;
@@ -47,24 +47,18 @@ struct TestConfigs
template <>
struct TestConfigs
{
- // Currently there are no fp8 instances for splitkv, pagedkv by default (the tests pass if such
- // instances are added), however the corresponding tests are not disabled (they will be skipped)
- // in case such instances will be added in the future.
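The fp8 TestConfigs specialization reworked in the remainder of this hunk follows the tests' per-dtype compile-time config pattern: a primary template supplies defaults and each data type specializes the value tables. A simplified, self-contained model of that pattern is sketched here; Fp8Tag and the Sketch names are placeholders, and the value lists are illustrative only.

#include <array>
#include <tuple>

struct Fp8Tag {}; // placeholder for the fp8 data-type tag used by the real tests

// Primary template: defaults shared by most data types.
template <class DataType>
struct TestConfigsSketch
{
    static constexpr auto HDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
    static constexpr bool squant     = false;
    static int adjust_seqlen(int seqlen) { return seqlen; } // identity by default
};

// fp8 specialization: static quantization on, head dim 256 enabled, and the
// old round-up of seqlen to a multiple of 128 dropped in favor of identity.
template <>
struct TestConfigsSketch<Fp8Tag>
{
    static constexpr auto HDimValues =
        std::array{std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
    static constexpr bool squant = true;
    static int adjust_seqlen(int seqlen) { return seqlen; }
};

static_assert(TestConfigsSketch<Fp8Tag>::squant && !TestConfigsSketch<float>::squant);

Dropping the seqlen round-up and enabling group mode together suggest that fp8 instances with seqlen padding are now generated, which is why the old workaround is kept only as a commented-out fallback below.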
-
- static constexpr auto HDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
+ static constexpr auto HDimValues =
+ std::array{std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
static constexpr auto SplitKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
static constexpr auto AppendKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
- // There are no fp8 instances with seqlen padding (mode_enum::group requires it)
- static constexpr auto ModeValues = std::array{mode_enum::batch};
- static constexpr auto IsVRowmajorValues = std::array{false};
- static constexpr bool squant = true;
- static constexpr bool def_lse = false;
- static constexpr bool def_is_v_rowmajor = true;
- static int adjust_seqlen(int seqlen)
- {
- // There are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests
- return ck_tile::integer_least_multiple(seqlen, 128);
- }
+ static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
+ static constexpr auto IsVRowmajorValues = std::array{true};
+ static constexpr bool squant = true;
+ static constexpr bool def_lse = false;
+ static constexpr bool def_is_v_rowmajor = true;
+ // If there are no fp8 instances with seqlen padding, pad seqlen to avoid skipping most of the tests:
+ // return ck_tile::integer_least_multiple(seqlen, 128);
+ static int adjust_seqlen(int seqlen) { return seqlen; }
};
template <>
struct TestConfigs