diff --git a/.github/scripts/therock_configure_ci.py b/.github/scripts/therock_configure_ci.py index cc66fdbfe8..860b6bf875 100644 --- a/.github/scripts/therock_configure_ci.py +++ b/.github/scripts/therock_configure_ci.py @@ -6,6 +6,7 @@ import subprocess import sys from typing import Iterable, Optional, Mapping + def gha_set_output(vars: Mapping[str, str | Path]): """Sets values in a step's output parameters. @@ -25,6 +26,7 @@ def gha_set_output(vars: Mapping[str, str | Path]): with open(step_output_file, "a") as f: f.writelines(f"{k}={str(v)}" + "\n" for k, v in vars.items()) + def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]: """Returns the paths of modified files relative to the base reference.""" try: @@ -42,11 +44,13 @@ def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]: file=sys.stderr, ) return None - + + GITHUB_WORKFLOWS_CI_PATTERNS = [ "therock*", ] + def is_path_workflow_file_related_to_ci(path: str) -> bool: return any( fnmatch.fnmatch(path, ".github/workflows/" + pattern) @@ -56,11 +60,13 @@ def is_path_workflow_file_related_to_ci(path: str) -> bool: for pattern in GITHUB_WORKFLOWS_CI_PATTERNS ) + def check_for_workflow_file_related_to_ci(paths: Optional[Iterable[str]]) -> bool: if paths is None: return False return any(is_path_workflow_file_related_to_ci(p) for p in paths) + # Paths matching any of these patterns are considered to have no influence over # build or test workflows so any related jobs can be skipped if all paths # modified by a commit/PR match a pattern in this list. 
@@ -70,23 +76,26 @@ SKIPPABLE_PATH_PATTERNS = [ "*.md", "*.pre-commit-config.*", "*LICENSE", - 'Jenkinsfile', - '.github/ISSUE_TEMPLATE/*', - '.github/CODEOWNERS', - '.github/*.md', - '.github/dependabot.yml', + "Jenkinsfile", + ".github/ISSUE_TEMPLATE/*", + ".github/CODEOWNERS", + ".github/*.md", + ".github/dependabot.yml", ] + def is_path_skippable(path: str) -> bool: """Determines if a given relative path to a file matches any skippable patterns.""" return any(fnmatch.fnmatch(path, pattern) for pattern in SKIPPABLE_PATH_PATTERNS) + def check_for_non_skippable_path(paths: Optional[Iterable[str]]) -> bool: """Returns true if at least one path is not in the skippable set.""" if paths is None: return False return any(not is_path_skippable(p) for p in paths) + def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool: """Returns true if CI workflows should run given a list of modified paths.""" @@ -118,16 +127,16 @@ def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool: ) return False + def main(args): base_ref = args.get("base_ref") modified_paths = get_modified_paths(base_ref) print("modified_paths (max 200):", modified_paths[:200]) enable_jobs = should_ci_run_given_modified_paths(modified_paths) - output = { - 'enable_therock_ci': json.dumps(enable_jobs) - } + output = {"enable_therock_ci": json.dumps(enable_jobs)} gha_set_output(output) + if __name__ == "__main__": args = {} args["base_ref"] = os.environ.get("BASE_REF", "HEAD^1") diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000000..16f7e2539c --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,16 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [develop] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + with: + python-version: '3.12' + - uses: pre-commit/action@v3.0.1 diff --git 
a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index beaabbe763..f4d0c0063c 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -35,6 +35,15 @@ jobs: with: repository: "ROCm/rocm-libraries" + - name: Pull DVC files for rocm-libraries # LOGNAME details here https://github.com/ROCm/rocm-libraries/pull/1617 + run: | + if command -v dvc &> /dev/null; then + echo "dvc detected" + else + echo "Warning, dvc not detected!" + fi + LOGNAME=github-runner dvc pull -v + - name: Checkout composable_kernel repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml index 068dbe3033..1ccc1d57bc 100644 --- a/.github/workflows/therock-test-component.yml +++ b/.github/workflows/therock-test-component.yml @@ -51,6 +51,7 @@ jobs: uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: "ROCm/TheRock" + ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit - name: Run setup test environment workflow uses: './.github/actions/setup_test_environment' diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index 54e068eb3d..efb5a6b1a0 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -27,6 +27,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" + ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit - name: "Configuring CI options" env: diff --git a/.gitignore b/.gitignore index e4dd8f7513..bcc5888b7f 100644 --- a/.gitignore +++ b/.gitignore @@ -36,7 +36,7 @@ tags # Editors .vscode -# build-in-source directory +# build-in-source directory (see exceptions below) build* # emacs temporary/backup files @@ -58,7 +58,7 @@ _doxygen/ docs/doxygen/html 
docs/doxygen/xml -# JetBrains IDE +# JetBrains IDE (see build* exceptions below) .idea/ cmake-build*/ build*/ @@ -71,3 +71,7 @@ __pycache__/ .cache/ +# Exceptions to build* patterns above +# The experimental/builder directory should be tracked despite matching build* +!experimental/builder +!experimental/builder/** diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2d936d3a48..04ebc6b45a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,25 @@ repos: -- repo: local +- repo: https://github.com/pre-commit/mirrors-clang-format + rev: v18.1.3 hooks: - id: clang-format - name: clang-format - entry: clang-format-18 -i --style=file - language: system types_or: [c++, inc] +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.0 + hooks: + - id: ruff-check + args: [ --fix ] + exclude: | + (?x)^( + docs/conf.py + )$ + - id: ruff-format + exclude: | + (?x)^( + docs/conf.py + )$ +- repo: local + hooks: # - id: copyright-year-checker # name: copyright-year-checker # entry: script/check_copyright_year.sh @@ -18,21 +32,12 @@ repos: language: script types_or: [c++, text] verbose: true - - id: ruff-check - name: Ruff Linter - entry: ruff check --fix + - id: remod-ck-tile + name: Run ck_tile remod.py + entry: python script/remod_for_ck_tile.py language: python - types: [python] - additional_dependencies: [ruff] - - id: ruff-format - name: Ruff Formatter - entry: ruff format - language: python - types: [python] - additional_dependencies: [ruff] - - id: run-remod-if-ck-tile-changed - name: Run remod.py if ck_tile files changed - entry: script/remod_for_ck_tile.sh - language: script - always_run: true + files: '^(include|example)/ck_tile/.*$' + additional_dependencies: + - dos2unix + - clang-format==18.1.3 pass_filenames: false diff --git a/CHANGELOG.md b/CHANGELOG.md index 9de78f3043..28bcaae5b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,35 +2,17 @@ Documentation for Composable Kernel available at 
[https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). -## Composable Kernel 1.2.0 for ROCm 7.0.0 +## (Unreleased) Composable Kernel for ROCm + +### Added -### Added * Added a compute async pipeline in the CK TILE universal GEMM on gfx950 * Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. * Added the new api to load different memory sizes to SGPR. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. * Added a basic copy kernel example and supporting documentation for new CK Tile developers. * Added support for grouped_gemm kernels to perform multi_d elementwise operation. -* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data -* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. -* Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). -* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW). -* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). -* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). -* Added support for Stream-K version of mixed fp8/bf16 GEMM -* Added support for Multiple D GEMM * Added support for Multiple ABD GEMM -* Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types -* Added support for FP16 2:4 structured sparsity to universal GEMM. -* Added support for Split K for grouped convolution backward data. -* Added logit soft-capping support for fMHA forward kernels. -* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv) -* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv/bwd) -* Added benchmarking support for tile engine GEMM. 
-* Added Ping-pong scheduler support for GEMM operation along the K dimension. -* Added rotating buffer feature for CK_Tile GEMM. -* Added int8 support for CK_TILE GEMM. -* Added support for elementwise kernel. * Added benchmarking support for tile engine GEMM Multi D. * Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands. * Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM. @@ -39,19 +21,50 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for batched contraction kernel. * Added pooling kernel in CK_TILE +### Changed + +* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594) + +## Composable Kernel 1.1.0 for ROCm 7.1.0 + +### Added + +* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv/bwd) +* Added support for elementwise kernel. + +### Upcoming changes + +* Non-grouped convolutions are deprecated. Their functionality is supported by grouped convolution. + +## Composable Kernel 1.1.0 for ROCm 7.0.0 + +### Added + +* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data +* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. +* Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). +* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW). +* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). +* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). 
+* Added support for Stream-K version of mixed fp8/bf16 GEMM +* Added support for Multiple D GEMM +* Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types +* Added support for FP16 2:4 structured sparsity to universal GEMM. +* Added support for Split K for grouped convolution backward data. +* Added logit soft-capping support for fMHA forward kernels. +* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv) +* Added benchmarking support for tile engine GEMM. +* Added Ping-pong scheduler support for GEMM operation along the K dimension. +* Added rotating buffer feature for CK_Tile GEMM. +* Added int8 support for CK_TILE GEMM. + ### Optimized +* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. +* Added Vectorize Transpose optimization for CK Tile +* Added the asynchronous copy for gfx950 -* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166) -* Added Vectorize Transpose optimization for CK Tile (#2131) -* Added the asynchronous copy for gfx950 (#2425) - - -### Fixes - -None - -### Changes +### Changed * Removed support for gfx940 and gfx941 targets (#1944) * Replaced the raw buffer load/store intrinsics with Clang20 built-ins (#1876) @@ -59,15 +72,6 @@ None * Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced. * Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced. * Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced. -* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594) - -### Known issues - -None - -### Upcoming changes - -* Non-grouped convolutions are deprecated. All of their functionality is supported by grouped convolution. 
## Composable Kernel 1.1.0 for ROCm 6.1.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index f4d3a83c34..f58dff8e15 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,8 +37,14 @@ include(CTest) option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON) option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) +if(CK_EXPERIMENTAL_BUILDER) + add_definitions(-DCK_EXPERIMENTAL_BUILDER) + include_directories(${PROJECT_SOURCE_DIR}/experimental/builder/include) +endif() + # Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8" # CK Codegen requires dataclass which is added in Python 3.7 # Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04 @@ -692,6 +698,10 @@ if (NOT MIOPEN_REQ_LIBS_ONLY) add_subdirectory(profiler) endif() +if (CK_EXPERIMENTAL_BUILDER) + add_subdirectory(experimental/builder) +endif() + if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS)) add_subdirectory(codegen) endif() diff --git a/Jenkinsfile b/Jenkinsfile index 3fbcdb5849..7a8574df05 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -71,7 +71,7 @@ def shouldRunCICheck() { ''' ).trim().split('\n') - if (changedFiles.isEmpty() || (changedFiles.size() == 1 && changedFiles[0].trim().isEmpty())) { + if (changedFiles.size() == 1 && changedFiles[0] == '') { echo "No changed files detected - this might be a manual trigger or merge commit, running CI for safety" return true } @@ -909,7 +909,7 @@ def run_aiter_tests(Map conf=[:]){ sh "rocminfo" sh "python3 --version" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" + //sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" //temporarily disable sh "python3 
/home/jenkins/workspace/aiter/op_tests/test_mha.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py" diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 0c81f8df98..4fdbb896de 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -99,6 +99,9 @@ else() -Wno-unused-lambda-capture -Wno-nvcc-compat ) + if(CK_CXX_STANDARD GREATER_EQUAL 20) + list(APPEND CMAKE_COMPILER_WARNINGS -Wno-c++20-compat) + endif() else() if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX") # cmake 3.5.2 does not support >=. diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 41e2fa2cc0..9336d47e71 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -1,3 +1,4 @@ +include_guard(GLOBAL) include(FetchContent) set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") diff --git a/example/01_gemm/gemm_wmma_fp16_v3.cpp b/example/01_gemm/gemm_wmma_fp16_v3.cpp index 7225dba721..7699364a7a 100644 --- a/example/01_gemm/gemm_wmma_fp16_v3.cpp +++ b/example/01_gemm/gemm_wmma_fp16_v3.cpp @@ -26,17 +26,18 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, PassThrough, PassThrough, PassThrough, GemmDefault, - 128, - 128, 64, - 64, 8, 8, + 256, + 128, 256, 64, + 8, 8, 16, 16, - 4, 2, - S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 2, 8, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, - S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, - 1, 1, S<1, 32, 1, 4>, 8, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>; + 1, 1, + S<1, 64, 1, 4>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff 
--git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 03c531c1ad..10dd4eaa1f 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -43,8 +43,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -190,11 +191,11 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index 5167097b6d..556aa90f3d 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -43,8 +43,9 @@ using S = ck::Sequence; using I8 = std::int8_t; using I32 = std::int32_t; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -190,11 +191,11 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, 
stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp index 1049b5d07c..8f8b2e80fe 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -42,8 +42,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -173,7 +174,7 @@ int main(int argc, char* argv[]) printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " "beta\n"); - exit(0); + exit(1); } auto f_host_tensor_descriptor = @@ -182,11 +183,11 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp index 992e7c19c8..17e9ceccec 100644 --- a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp +++ b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp @@ -25,8 +25,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; 
using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -160,23 +161,22 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; + ck::index_t StrideD = 0; + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d_m_n(f_host_tensor_descriptor(M, N, 0, ELayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, ELayout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - const auto StrideD = std::is_same::value - ? d_m_n.mDesc.GetStrides()[0] - : d_m_n.mDesc.GetStrides()[1]; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; diff --git a/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc b/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc index 796a5d3e9b..c05e0d19aa 100644 --- a/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc +++ b/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc @@ -6,6 +6,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); #endif using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; ProblemSize ps = problem_size; // make mutable copy because default stride values of 0 need to be updated @@ -15,11 +16,11 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC [](std::size_t row, std::size_t col, std::size_t stride, auto 
layout) { if constexpr(std::is_same_v) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; @@ -43,7 +44,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - // If any user-provided leading stride <= 0, replace it with the one determined by the + // If any user-provided leading stride < 0, replace it with the one determined by the // created tensor descriptor. For RowMajor the leading stride is index 0, for ColMajor index 1. auto fetch_leading_stride = [](const auto& tensor, auto layout_tag) -> int { if constexpr(std::is_same_v) @@ -56,15 +57,15 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC } }; - if(StrideA <= 0) + if(StrideA < 0) StrideA = fetch_leading_stride(a_m_k, ALayout{}); - if(StrideB <= 0) + if(StrideB < 0) StrideB = fetch_leading_stride(b_k_n, BLayout{}); - if(StrideD0 <= 0) + if(StrideD0 < 0) StrideD0 = fetch_leading_stride(d0_m_n, D0Layout{}); - if(StrideD1 <= 0) + if(StrideD1 < 0) StrideD1 = fetch_leading_stride(d1_m_n, D1Layout{}); - if(StrideE <= 0) + if(StrideE < 0) StrideE = fetch_leading_stride(e_m_n_host_result, ELayout{}); switch(config.init_method) diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index 4a701e7792..f4e6b4d6e3 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -25,8 +25,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; 
+using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using ADataType = F16; using BDataType = F16; @@ -138,12 +139,12 @@ int main(int argc, char* argv[]) if(std::is_same::value) { return HostTensorDescriptor( - {batch_count, row, col}, {row * stride, stride, 1_uz}, layout); + {batch_count, row, col}, {row * stride, stride, 1_uz}, Bypass{}); } else { return HostTensorDescriptor( - {batch_count, row, col}, {col * stride, 1_uz, stride}, layout); + {batch_count, row, col}, {col * stride, 1_uz, stride}, Bypass{}); } }; diff --git a/example/24_batched_gemm/run_batched_gemm_example.inc b/example/24_batched_gemm/run_batched_gemm_example.inc index 182ab8d967..666f17ca08 100644 --- a/example/24_batched_gemm/run_batched_gemm_example.inc +++ b/example/24_batched_gemm/run_batched_gemm_example.inc @@ -31,6 +31,7 @@ struct ExecutionConfig final bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; #if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); @@ -62,12 +63,12 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co if(std::is_same::value) { return HostTensorDescriptor( - {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, Bypass{}); } else { return HostTensorDescriptor( - {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, Bypass{}); } }; diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc index 5e56670fcf..34164b27d1 100644 --- 
a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -116,6 +116,7 @@ inline __host__ __device__ constexpr double get_atol() bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; auto& [M, N, @@ -138,12 +139,12 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co if constexpr(std::is_same_v) { return HostTensorDescriptor( - {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, Bypass{}); } else { return HostTensorDescriptor( - {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, Bypass{}); } }; diff --git a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc index 6ed0b23407..1efbfbd540 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc @@ -37,6 +37,7 @@ struct ExecutionConfig final bool run_batched_gemm_rowwise(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; auto& [M, N, @@ -65,12 +66,12 @@ bool run_batched_gemm_rowwise(const ProblemSize& problem_size, const ExecutionCo if(std::is_same::value) { return HostTensorDescriptor( - {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, Bypass{}); } else { return HostTensorDescriptor( - {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, Bypass{}); } }; diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc 
b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc index 7a03e9cacf..40cec7ef11 100644 --- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc @@ -1,8 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +using Bypass = ck::tensor_layout::BypassLayoutVerification; + bool run_batched_gemm_gemm_example(int argc, char* argv[]) { bool do_verification = true; @@ -111,12 +113,12 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[]) if(std::is_same::value) { return HostTensorDescriptor( - {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); + {batch_count, row, col}, {batch_stride, stride, 1_uz}, Bypass{}); } else { return HostTensorDescriptor( - {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); + {batch_count, row, col}, {batch_stride, 1_uz, stride}, Bypass{}); } }; diff --git a/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc b/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc index bba6ae14a4..a3e1f325bd 100644 --- a/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc +++ b/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc @@ -4,27 +4,21 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; - ProblemSize ps = - problem_size; // make mutable copy because default stride values of 0 need to be updated - auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = ps; + auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = problem_size; - auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, int& stride, auto layout) { - if(std::is_same::value) - { - auto desc = 
HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); - if(stride <= 0) - stride = desc.GetStrides()[0]; - return desc; - } - else - { - auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); - if(stride <= 0) - stride = desc.GetStrides()[1]; - return desc; - } - }; + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); + } + }; Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp index 055d253042..63343df3a8 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp @@ -27,8 +27,9 @@ using BF16 = ck::bhalf_t; using I8 = int8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Add = ck::tensor_operation::element_wise::Add; @@ -110,11 +111,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, 
Bypass{}); } }; diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp index 1ba8133ea7..78f7d954f0 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp @@ -26,8 +26,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Add = ck::tensor_operation::element_wise::Add; @@ -109,11 +110,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp index a30314f58c..d40d09540f 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp @@ -27,7 +27,8 @@ using BF16 = ck::bhalf_t; using I8 = int8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = BF16; using AsDataType = ck::Tuple; @@ -161,11 +162,11 @@ int main(int argc, char* argv[]) 
if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp index 086a0f4834..102b7f50de 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp @@ -27,7 +27,8 @@ using BF16 = ck::bhalf_t; using I8 = int8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = BF16; using AsDataType = ck::Tuple; @@ -157,11 +158,11 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp index 32345d1263..aeaa5fe776 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp @@ -24,7 +24,8 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -220,11 +221,11 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 
1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp index 00e2d7e33c..9363953a6e 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp @@ -27,7 +27,8 @@ using BF16 = ck::bhalf_t; using I8 = int8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = BF16; using AsDataType = ck::Tuple; @@ -160,11 +161,11 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp index 405eac7df1..a599f9d032 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp @@ -28,8 +28,9 @@ using BF16 = ck::bhalf_t; using I8 = int8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = BF16; using AsDataType = ck::Tuple; @@ 
-121,27 +122,19 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = [](std::size_t row, - std::size_t col, - ck::index_t& stride, - auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); - if(stride <= 0) - stride = desc.GetStrides()[0]; - return desc; - } - else - { - auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); - if(stride <= 0) - stride = desc.GetStrides()[1]; - return desc; - } - }; + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp index 50e670bdf3..d7e316e1e0 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp @@ -28,8 +28,9 @@ using BF16 = ck::bhalf_t; using I8 = int8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = BF16; using AsDataType = ck::Tuple; @@ -121,27 +122,19 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = [](std::size_t row, - std::size_t col, - ck::index_t& stride, - auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = + 
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); - if(stride <= 0) - stride = desc.GetStrides()[0]; - return desc; - } - else - { - auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); - if(stride <= 0) - stride = desc.GetStrides()[1]; - return desc; - } - }; + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp index 50e1c21c8f..83cc61284e 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp @@ -28,8 +28,9 @@ using BF16 = ck::bhalf_t; using I8 = int8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = BF16; using AsDataType = ck::Tuple; @@ -120,27 +121,19 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = [](std::size_t row, - std::size_t col, - ck::index_t& stride, - auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - auto desc = HostTensorDescriptor({row, col}, 
{static_cast(stride), 1_uz}); - if(stride <= 0) - stride = desc.GetStrides()[0]; - return desc; - } - else - { - auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); - if(stride <= 0) - stride = desc.GetStrides()[1]; - return desc; - } - }; + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index 8da49ef85d..43637e4a1f 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -28,8 +28,9 @@ using F16 = ck::half_t; using FP8 = ck::f8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = FP8; using B0DataType = FP8; @@ -147,11 +148,11 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 72ea7f1cb6..2cb2dc17f4 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -28,8 +28,9 @@ using F16 = ck::half_t; 
using F8 = ck::f8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F8; using B0DataType = F8; @@ -242,7 +243,7 @@ int main(int argc, char* argv[]) printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4 to 5: N, K, tokens\n"); - exit(0); + exit(1); } ck::index_t sorted_size = sorted_tile_num * MPerBlock; @@ -294,7 +295,7 @@ int main(int argc, char* argv[]) Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n( HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); - Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}, Bypass{})); Tensor e_t_n_host_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp index 66627a6de6..bca5ffec78 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp @@ -30,8 +30,9 @@ using F8 = ck::f8_t; using F32 = float; using I64 = int64_t; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F8; using A1DataType = F32; @@ -312,7 +313,7 @@ int main(int argc, char* argv[]) Col{})); Tensor b0_preshuffled( HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); - Tensor 
d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}, Bypass{})); Tensor e_t_n_host_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 5e306ac6dd..d14885e7f2 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -29,8 +29,9 @@ using F16 = ck::half_t; using F8 = ck::f8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F8; using B0DataType = I4; @@ -222,7 +223,7 @@ int main(int argc, char* argv[]) printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4 to 5: N, K, tokens\n"); - exit(0); + exit(1); } if(tokens * topk > valid_size) @@ -268,10 +269,10 @@ int main(int argc, char* argv[]) HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b0_preshuffled( HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); - Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); + Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}, Bypass{})); Tensor d1_e_n( HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); - Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}, Bypass{})); Tensor e_t_n_host_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( @@ -289,7 +290,6 @@ int main(int argc, char* argv[]) case 0: break; case 
1: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -303,7 +303,6 @@ int main(int argc, char* argv[]) break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index a6c5a8914f..d80c75abe8 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -286,7 +286,7 @@ int main(int argc, char* argv[]) HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0}, Bypass{})); Tensor d1_e_n( HostTensorDescriptor({experts, N}, {PerTokenQuant ? 
StrideDs[1] * N : 1, StrideDs[1]})); - Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}, Bypass{})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); e_t_n_device_result.SetZero(); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index cc42c4b815..02369f344e 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -30,8 +30,9 @@ using F8 = ck::f8_t; using F32 = float; using I64 = int64_t; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F8; using A1DataType = F32; @@ -305,7 +306,7 @@ int main(int argc, char* argv[]) Col{})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); - Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}, Bypass{})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); e_t_n_device_result.SetZero(); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 29e758f9d4..cafea72559 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -178,21 +178,17 @@ int main(int argc, char* argv[]) { // use default case } - else if(argc == 3) - { - // use default case - do_verification = std::stoi(argv[1]); - init_method = 
std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 7) + else if(argc == 4 || argc == 7) /* argv[1..3] are read below, so three positional args means argc == 4; with argc == 3, argv[3] is the null terminator */ { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); - N = std::stoi(argv[4]); - K = std::stoi(argv[5]); - tokens = std::stoi(argv[6]); + if(argc == 7) + { + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } } else { @@ -200,7 +196,7 @@ int main(int argc, char* argv[]) printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4 to 6: N, K, tokens\n"); - exit(0); + exit(1); } ck::index_t StrideA = K; @@ -244,8 +240,8 @@ int main(int argc, char* argv[]) Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}, Bypass{})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); - Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}, Bypass{})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}, Bypass{})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); e_t_n_device_result.SetZero(); @@ -275,7 +271,7 @@ int main(int argc, char* argv[]) break; case 3: a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -289,7 +285,7 @@ int main(int argc, char* argv[]) break; default: a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - 
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp index 69c0d6558f..0c51a24679 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -31,8 +31,9 @@ using F32 = float; using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F4; using A1DataType = XPackedDataType; @@ -285,7 +286,7 @@ int main(int argc, char* argv[]) HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, Col{})); - Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}, Bypass{})); Tensor e_t_k_n_host_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp index 2f7762386d..b6d5d8f211 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp @@ -31,8 +31,9 @@ using F32 = float; using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; 
+using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F4; using A1DataType = XPackedDataType; @@ -282,7 +283,7 @@ int main(int argc, char* argv[]) HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, Col{})); - Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}, Bypass{})); Tensor e_t_k_n_host_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp index 4ef068c41f..1adf039b70 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp @@ -32,8 +32,9 @@ using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t using I64 = int64_t; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F4; using A1DataType = XPackedDataType; @@ -315,7 +316,7 @@ int main(int argc, char* argv[]) HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, Col{})); - Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}, Bypass{})); Tensor e_t_k_n_host_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp 
b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp index 317b0f9f15..61a63b47ac 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -31,8 +31,9 @@ using F32 = float; using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F4; using A1DataType = XPackedDataType; @@ -290,7 +291,7 @@ int main(int argc, char* argv[]) HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, {N * Scale_Stride_BN, 1, Scale_Stride_BN}, Col{})); - Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}, Bypass{})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index 5bb6454d2a..2670468c4b 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -31,8 +31,9 @@ using F32 = float; using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F4; using A1DataType = XPackedDataType; @@ -290,7 +291,7 @@ int main(int argc, char* argv[]) HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, {N * 
Scale_Stride_BN, 1, Scale_Stride_BN}, Col{})); - Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}, Bypass{})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp index 333f8a3d52..c3454be84a 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -32,8 +32,9 @@ using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t using I64 = int64_t; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F4; using A1DataType = XPackedDataType; @@ -325,7 +326,7 @@ int main(int argc, char* argv[]) HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, {N * Scale_Stride_BN, 1, Scale_Stride_BN}, Col{})); - Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}, Bypass{})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 2b872cb9b5..a77d7e6be3 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -4,13 +4,28 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile ## build ``` -# in the root of ck_tile -mkdir build && cd build -# you can replace with the appropriate architecture (for example gfx90a 
or gfx942) or leave it blank -../script/cmake-ck-dev.sh ../ -make tile_example_fmha_fwd -j +# 1. In the root of composable_kernel project, create the build directory. +[~/composable_kernel] mkdir build && cd build +# 2. In the build directory, run the CMake wrapper script to generate the build system files. Replace with the gfx architectures string. +[~/composable_kernel/build] ../script/cmake-ck-dev.sh .. -G Ninja +# 3. In the build directory, run the build system recipe. +[~/composable_kernel/build] ninja tile_example_fmha_fwd ``` -This will result in an executable `build/bin/tile_example_fmha_fwd` +Running the build recipe will produce the executable `tile_example_fmha_fwd`. + +The executables reside in `bin` subdirectory of the build directory. + +This example provides recipes for `tile_example_fmha_fwd`, `tile_example_fmha_bwd`, `tile_example_fmha_fwd_v3`. + +> [!NOTE] +> `cmake-ck-dev.sh` is a CMake wrapper. +> +> The first argument is the path to composable_kernel sources. +> +> The second argument is the gfx architectures string (e.g. "gfx950" or "gfx90a;gfx942"). +> +> The remaining arguments are optional and are passed through to CMake. +> E.g. `-G Ninja` specifies ninja as the build system. ## kernel The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. diff --git a/example/ck_tile/01_fmha/codegen/cmake_config.py b/example/ck_tile/01_fmha/codegen/cmake_config.py index 03ebfd6702..483934b03b 100644 --- a/example/ck_tile/01_fmha/codegen/cmake_config.py +++ b/example/ck_tile/01_fmha/codegen/cmake_config.py @@ -2,4 +2,4 @@ # Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
# generate kernel instances to speed up compilation -GEN_DIR = "" # in Cmake, have to generate files in same folder \ No newline at end of file +GEN_DIR = "" # in Cmake, have to generate files in same folder diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 81d34484a5..4098eb67c2 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -3,38 +3,35 @@ # generate kernel instances to speed up compilation FWD_DTYPE_MAP = { - "fp32" : "FmhaFwdFp32", - "fp16" : "FmhaFwdFp16", - "bf16" : "FmhaFwdBf16", - "fp8" : "FmhaFwdFp8", + "fp32": "FmhaFwdFp32", + "fp16": "FmhaFwdFp16", + "bf16": "FmhaFwdBf16", + "fp8": "FmhaFwdFp8", "fp8fp16": "FmhaFwdFp8Fp16", "fp8bf16": "FmhaFwdFp8Bf16", - "fp8fp32": "FmhaFwdFp8Fp32" + "fp8fp32": "FmhaFwdFp8Fp32", } -BWD_DTYPE_MAP = { - "fp32": "FmhaBwdFp32", - "fp16": "FmhaBwdFp16", - "bf16": "FmhaBwdBf16" -} +BWD_DTYPE_MAP = {"fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"} MASK_IMPL = { - "generic" : "ck_tile::GenericAttentionMask", - "simplified" : "ck_tile::SimplifiedGenericAttentionMask" + "generic": "ck_tile::GenericAttentionMask", + "simplified": "ck_tile::SimplifiedGenericAttentionMask", } _MASK_SIMPLIFIED_MAP = { - "s_no" : "ck_tile::SimplifiedGenericAttentionMask", - "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", + "s_no": "ck_tile::SimplifiedGenericAttentionMask", + "s_mask": "ck_tile::SimplifiedGenericAttentionMask", } _MASK_MAP = { - "no" : "FmhaMasks::NoMask", - "causal" : "FmhaMasks::CausalMask", - "generic" : "FmhaMasks::GenericMask" + "no": "FmhaMasks::NoMask", + "causal": "FmhaMasks::CausalMask", + "generic": "FmhaMasks::GenericMask", } -def get_mask_map(mask : str): + +def get_mask_map(mask: str): if mask == "generic": return _MASK_MAP elif mask == "simplified": @@ -43,18 +40,20 @@ def get_mask_map(mask : str): assert False return None + _MASK_CHECK_MAP = { - "no" : 
"t.mask_type == mask_enum::no_mask", - "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", - "generic" : "t.mask_type == mask_enum::window_generic", + "no": "t.mask_type == mask_enum::no_mask", + "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", + "generic": "t.mask_type == mask_enum::window_generic", } _MASK_SIMPLIFIED_CHECK_MAP = { - "s_no" : "t.mask_type == mask_enum::no_mask", - "s_mask" : "t.mask_type != mask_enum::no_mask", + "s_no": "t.mask_type == mask_enum::no_mask", + "s_mask": "t.mask_type != mask_enum::no_mask", } -def get_mask_check_map(mask : str): + +def get_mask_check_map(mask: str): if mask == "generic": return _MASK_CHECK_MAP elif mask == "simplified": @@ -63,76 +62,71 @@ def get_mask_check_map(mask : str): assert False return None + BIAS_MAP = { - "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", - "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", - "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" + "no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS", + "bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + "alibi": "ck_tile::BlockAttentionBiasEnum::ALIBI", } # TODO: this is ugly BIAS_CHECK_MAP = { - "no" : "bias_enum::no_bias", - "bias" : "bias_enum::elementwise_bias", - "alibi" : "bias_enum::alibi" + "no": "bias_enum::no_bias", + "bias": "bias_enum::elementwise_bias", + "alibi": "bias_enum::alibi", } DROPOUT_MAP = { - "no" : "ck_tile::BlockDropoutBwd", - "dropout_wg32" : "ck_tile::BlockDropoutBwd", - "dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd", - "dropout_wg16" : "ck_tile::BlockDropoutBwd", - "dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd" + "no": "ck_tile::BlockDropoutBwd", + "dropout_wg32": "ck_tile::BlockDropoutBwd", + "dropout_wg32_storerandval": "ck_tile::BlockDropoutBwd", + "dropout_wg16": "ck_tile::BlockDropoutBwd", + "dropout_wg16_storerandval": "ck_tile::BlockDropoutBwd", } DROPOUT_CHECK_MAP = { - 
"no" : "t.has_dropout == false", - "dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false", - "dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true", - "dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false", - "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true", + "no": "t.has_dropout == false", + "dropout_wg32": "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg32_storerandval": "t.has_dropout == true && t.is_store_randval == true", + "dropout_wg16": "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg16_storerandval": "t.has_dropout == true && t.is_store_randval == true", } ROPE_MAP = { - "no" : "ck_tile::RotaryEmbeddingEnum::NONE", - "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", - "half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED" + "no": "ck_tile::RotaryEmbeddingEnum::NONE", + "inter": "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", + "half": "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED", } ROPE_CHECK_MAP = { - "no" : "rope_enum::none", - "inter" : "rope_enum::interleaved", - "half" : "rope_enum::half_rotated" + "no": "rope_enum::none", + "inter": "rope_enum::interleaved", + "half": "rope_enum::half_rotated", } -MODE_MAP = { - "batch" : "false", - "group" : "true" -} +MODE_MAP = {"batch": "false", "group": "true"} -LAYOUT_MAP = { - "row" : "true", - "col" : "false" -} +LAYOUT_MAP = {"row": "true", "col": "false"} PIPELINE_MAP = { - "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", - "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", - "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", - "qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", + "qr": "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync", + "qs": "ck_tile::BlockFmhaPipelineQSKSVS", + "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", } PIPELINE_ENUM_MAP = { - "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - 
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", - "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", - "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", + "qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS", + "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", } BOOL_MAP = { - "t" : "true", - "f" : "false", - True : "true", - False : "false", + "t": "true", + "f": "false", + True: "true", + False: "false", } diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index e2f69fa49a..2e3f96e4a6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -9,28 +9,26 @@ import itertools from pathlib import Path from typing import List, Optional, Tuple -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * +from codegen.cmake_config import GEN_DIR +from codegen.cpp_symbol_map import ( + MODE_MAP, + LAYOUT_MAP, + BIAS_CHECK_MAP, + get_mask_check_map, + get_mask_map, + BIAS_MAP, + FWD_DTYPE_MAP, + BOOL_MAP, + PIPELINE_ENUM_MAP, +) -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} -K0_MAX_SUBMAX_MAP = { - 32 : 32, - 64 : 64, - 96 : 128, - 128: 128, - 256: 256 -} +K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} FMHA_BATCH_PREFILL_PIPELINE_MAP = { - "qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", + "qr_async": 
"ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", } FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT @@ -40,7 +38,7 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT #include "fmha_fwd.hpp" """ -FMHA_FWD_KERNEL_BODY=""" +FMHA_FWD_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; @@ -116,8 +114,8 @@ float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_b }} """ -FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp" -FMHA_FWD_API=""" +FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp" +FMHA_FWD_API = """ #include namespace {{ @@ -167,173 +165,223 @@ float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, }} """ -FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} """ -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ {F_inner_dispatch} }} """ -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, 
{F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; return fmha_batch_prefill_(s, a); }} """ + @dataclass class CppConstraint: bool_expr: str = None def __str__(self): if self.bool_expr is None: - return 'true' + return "true" else: - return f'{self.bool_expr}' + return f"{self.bool_expr}" def __and__(self, other): - return CppConstraint(f'({str(self)}) && ({str(other)})') + return CppConstraint(f"({str(self)}) && ({str(other)})") + @dataclass class FmhaFwdApiTrait: - pipeline_tag : str + pipeline_tag: str # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - logits : str - mask : str - bias : str # - lse : str # - dropout : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - constraint : CppConstraint + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: str + mask: str + bias: str # + lse: str # + dropout: str + squant: str # + spad: str + skpad: str + dpad: str + dvpad: str + constraint: CppConstraint @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + return ( + 
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + ) @property def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag == "qr_async": + if self.spad == "t": + return "true" # always support + else: + return "true" + elif self.pipeline_tag in ["qr"]: + if self.spad == "t": + return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.seqlen_q % {self.bm0} == 0" + else: + assert False @property def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_fp8']: - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' - else: assert False + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag == "qr_async": + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + else: + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + elif self.pipeline_tag in ["qr", "qr_fp8"]: + if self.skpad == "t": + return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.seqlen_k % {self.bn0} == 0" + else: + assert False @property def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag == "qr_async": vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr']: + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {bk0submax} == 0' - else: assert False + if self.dpad == "t": + return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.hdim_q % {bk0submax} == 0" + else: + assert False @property def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag == "qr_async": vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr']: + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) - else : return f'a.hdim_v % {bk0submax} == 0' - else: assert False + if self.dvpad == "t": + return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.hdim_v % {bk0submax} == 0" + else: + assert False + @dataclass class FmhaFwdPipeline: - tag : str + tag: str - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_logits : str # t/f - F_bias : str # true/false - F_lse : str # - F_dropout : str # - F_squant : str # - F_mask : str # value from MASK_MAP - F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_bias: str # true/false + F_lse: str # + F_dropout: str # + F_squant: str # + F_mask: str # value from MASK_MAP + F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property def name(self) -> str: def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n return n + pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_logits == 't' : n += '_logits' - else: n += '_nlogits' - - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' - - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' + n += "_npad" - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' + if 
self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" - if self.F_dropout == 't' : n += '_dropout' - else: n += '_ndropout' + if self.F_bias != "no": + n += f"_{self.F_bias}" + else: + n += "_nbias" - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' + if self.F_mask[0:2] == "s_": + if self.F_mask == "s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" + + if self.F_lse == "t": + n += "_lse" + else: + n += "_nlse" + + if self.F_dropout == "t": + n += "_dropout" + else: + n += "_ndropout" + + if self.F_squant == "t": + n += "_squant" + else: + n += "_nsquant" return n + class FmhaFwdApiPool: def __init__(self, mask_impl): self.pool = dict() self.mask_impl = mask_impl - def register_traits(self, trait : FmhaFwdApiTrait) -> None: + def register_traits(self, trait: FmhaFwdApiTrait) -> None: # TODO: do we need to check duplication? if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() @@ -344,118 +392,152 @@ class FmhaFwdApiPool: @property def api(self) -> str: - per_dtypes=str() + per_dtypes = str() for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() + per_hdim_case = str() for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() + traits = self.pool[dtype][hdim] + inners = str() for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant], - F_scheck=trait.scheck, 
F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_mode=MODE_MAP[trait.mode], + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_logits=BOOL_MAP[trait.logits], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], + F_dropout=BOOL_MAP[trait.dropout], + F_squant=BOOL_MAP[trait.squant], + F_scheck=trait.scheck, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, 
F_dtype=dtype, F_hdim_case=per_hdim_case + ) if not per_dtypes: # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + per_dtypes += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes) + @dataclass class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0 : int # number of warps for gemm0 along q seqlen - F_rn0 : int # number of warps for gemm0 along k seqlen - F_rk0 : int # number of warps for gemm0 along head dim q (not used) - F_rm1 : int # number of warps for gemm1 along q seqlen - F_rn1 : int # number of warps for gemm1 along head dim v - F_rk1 : int # number of warps for gemm1 along k seqlen (not used) - F_wm0 : int # gemm0 warp size along m - F_wn0 : int # gemm0 warp size along n - F_wk0 : int # gemm0 warp size along k - F_wm1 : int # gemm1 warp size along m - F_wn1 : int # gemm1 warp size along n - F_wk1 : int # gemm1 warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps 
for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ - f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ - f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + return ( + f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + @dataclass class FmhaFwdKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdPipeline - mask_impl : str + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdPipeline + mask_impl: str @property def template(self) -> str: - kernel_body = str() - return 
FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0max = self.F_tile.F_bk0max, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_wm0 = self.F_tile.F_wm0, - F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits = BOOL_MAP[self.F_pipeline.F_logits], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag]) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + 
F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits=BOOL_MAP[self.F_pipeline.F_logits], + F_bias=BIAS_MAP[self.F_pipeline.F_bias], + F_lse=BOOL_MAP[self.F_pipeline.F_lse], + F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], + F_squant=BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag], + ) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name + return ( + f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) @property def filename(self) -> str: @@ -463,36 +545,38 @@ class FmhaFwdKernel: def api_trait(self) -> FmhaFwdApiTrait: return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - logits=self.F_pipeline.F_logits, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + 
dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + class KernelComponentFactory: @staticmethod - def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": return { 128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - } + } # fmt: skip else: return None @@ -502,28 +586,38 @@ class KernelComponentFactory: # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? 
- squant = 't' if dtype == 'fp8' else 'f' + squant = "t" if dtype == "fp8" else "f" pipelines = [] - if dtype in ['fp16', 'bf16']: - for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) - # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + if dtype in ["fp16", "bf16"]: + for logits, mask, bias, lse, dropout in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t", "f"], + ["t", "f"], + ): + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip + # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip + # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip else: assert False return pipelines + class CustomFactory(KernelComponentFactory): @staticmethod - def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) - if dtype == 'fp16' or dtype == 'bf16': + if dtype == "fp16" or dtype == "bf16": if 128 in result.keys(): - result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, 
CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) + result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip return result -def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + +def get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future @@ -532,30 +626,41 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl for dtype in FWD_DTYPE_MAP.keys(): d = CustomFactory.get_hdim_tile_size_dict(dtype) - if d == None: + if d is None: continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): - for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)): + for tile, pipeline in itertools.product( + tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl) + ): if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue if hdim == 192 and tile.F_bn1 == 128: # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': + if ( + pipeline.F_bias != "no" + or pipeline.F_lse == "t" + or pipeline.F_dropout == "t" + ): continue # 
logits_soft_cap is only allowed if no bias - if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + if not ( + (pipeline.F_logits == "t" and pipeline.F_bias == "no") + or pipeline.F_logits == "f" + ): continue - k = FmhaFwdKernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': + k = FmhaFwdKernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue if optdim_list != [-1]: @@ -563,48 +668,48 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue # 2 - Flash attention integration if receipt in (2, 3): - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "alibi"] + cond &= pipeline.F_squant == "f" if not cond: continue # PyTorch integration elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'bias'] - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "bias"] + cond &= pipeline.F_squant == "f" if not cond: continue # Aiter(mha_fwd) integration elif receipt == 100: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'batch' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue # Aiter(mha_batch_prefill) integration elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'group' - cond &= 
pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue # aiter::mha_batch_prefill C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'group' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue # fp32 only if receipt == 800 or receipt == 801: - cond = dtype == 'fp32' + cond = dtype == "fp32" if not cond: continue @@ -613,20 +718,28 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl return (api_pool, gen) + def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: (autogen_dir / kernel.filename).write_text(kernel.template) -def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + +def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: - with file_path.open('a') as f: + +def list_blobs( + file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") diff 
--git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 059be0e490..d007b4caa3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -10,8 +10,18 @@ from pathlib import Path from typing import List, Tuple, Dict, Literal, Any from collections import defaultdict -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * +from codegen.cmake_config import GEN_DIR +from codegen.cpp_symbol_map import ( + get_mask_check_map, + BIAS_CHECK_MAP, + DROPOUT_CHECK_MAP, + MODE_MAP, + get_mask_map, + BIAS_MAP, + DROPOUT_MAP, + BWD_DTYPE_MAP, + BOOL_MAP, +) from codegen.utils import update_file @@ -21,7 +31,7 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT #include "fmha_bwd.hpp" """ -FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" +FMHA_BWD_DQ_DK_DV_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile:: @@ -164,8 +174,8 @@ std::string fmha_bwd_dq_dk_dv_get_name_() }} """ -FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" -FMHA_BWD_API=""" +FMHA_BWD_API_FILENAME = "fmha_bwd_api.cpp" +FMHA_BWD_API = """ #include template @@ -201,17 +211,18 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf }} """ -def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) -> str: + +def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_=0) -> str: lines = [ f"{'if' if if_ == 0 else 'else if'}({F_cond})", "{", - *[' ' + line for line in F_body.split('\n') if line.strip() != ''], + *[" " + line for line in F_body.split("\n") if line.strip() != ""], "}", ] - return '\n'.join(' ' * indent + line for line in lines) + '\n' + return "\n".join(" " * indent + line for line in lines) + "\n" -FMHA_BWD_API_INNER_DISPATCH=""" +FMHA_BWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == 
{F_dbias}) && ({F_dropout_check}) && ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>; @@ -225,6 +236,7 @@ FMHA_BWD_API_INNER_DISPATCH=""" # M0 size for 1d kernels (dot/convert) M0_1D = 64 + # GEMM0: Q@K=S^T # GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) # GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) @@ -233,174 +245,197 @@ M0_1D = 64 # Is it necessary to distinguish between K0~K4? @dataclass(frozen=True) class FmhaBwdDQDKDVTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along gemm0 unroll(F_bhdq) - F_bk1 : int # tile size along gemm1 unroll(F_bm0) - F_bk2 : int # tile size along gemm2 unroll(F_bhdv) - F_bk3 : int # tile size along gemm3 unroll(F_bm0) - F_bk4 : int # tile size along gemm4 unroll(F_bn0) - F_bhdq : int # q head_dim - F_bhdv : int # v head_dim - F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 - F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 - F_rk0 : int # number of warps along headdim_qk/v (not used) in gemm0/gemm2 - F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 - F_rn1 : int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3 - F_rk1 : int # number of warps along q seqlen (not used) in gemm1/gemm3 - F_rm2 : int # number of warps along q seqlen (block warps) in gemm4 - F_rn2 : int # number of warps along headdim_qk (block warps) in gemm4 - F_rk2 : int # number of warps along k seqlen (not used) in gemm4 - F_wm0 : int # warp size along m in gemm0/gemm2/gemm4 - F_wn0 : int # warp size along n in gemm0/gemm2/gemm4 - F_wk0 : int # warp size along k in gemm0/gemm2/gemm4 - F_wm1 : int # warp size along m in gemm1/gemm3 - F_wn1 : int # warp 
size along n in gemm1/gemm3 - F_wk1 : int # warp size along k in gemm1/gemm3 - F_occupancy : int # occupancy - max_seq_q : int = 0 + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along gemm0 unroll(F_bhdq) + F_bk1: int # tile size along gemm1 unroll(F_bm0) + F_bk2: int # tile size along gemm2 unroll(F_bhdv) + F_bk3: int # tile size along gemm3 unroll(F_bm0) + F_bk4: int # tile size along gemm4 unroll(F_bn0) + F_bhdq: int # q head_dim + F_bhdv: int # v head_dim + F_rm0: int # number of warps along q seqlen (block warps) in gemm0/gemm2 + F_rn0: int # number of warps along k seqlen (block warps) in gemm0/gemm2 + F_rk0: int # number of warps along headdim_qk/v (not used) in gemm0/gemm2 + F_rm1: int # number of warps along k seqlen (block warps) in gemm1/gemm3 + F_rn1: int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3 + F_rk1: int # number of warps along q seqlen (not used) in gemm1/gemm3 + F_rm2: int # number of warps along q seqlen (block warps) in gemm4 + F_rn2: int # number of warps along headdim_qk (block warps) in gemm4 + F_rk2: int # number of warps along k seqlen (not used) in gemm4 + F_wm0: int # warp size along m in gemm0/gemm2/gemm4 + F_wn0: int # warp size along n in gemm0/gemm2/gemm4 + F_wk0: int # warp size along k in gemm0/gemm2/gemm4 + F_wm1: int # warp size along m in gemm1/gemm3 + F_wn1: int # warp size along n in gemm1/gemm3 + F_wk1: int # warp size along k in gemm1/gemm3 + F_occupancy: int # occupancy + max_seq_q: int = 0 @property def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ - f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ - f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}_maxq{self.max_seq_q}" + return ( + 
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}_maxq{self.max_seq_q}" + ) + @dataclass(frozen=True) class FmhaBwdDQDKDVKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_tile : FmhaBwdDQDKDVTileSize - F_dpad : Literal[0, 8 ,1] - F_dvpad : Literal[0, 8 ,1] - F_bias : str # - F_dbias : str # - F_dropout : str # - F_mask : str # value from MASK_MAP - F_mode : str # value from MODE_MAP - F_deterministic : str # - mask_impl : str # - F_trload : str # + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_tile: FmhaBwdDQDKDVTileSize + F_dpad: Literal[0, 8, 1] + F_dvpad: Literal[0, 8, 1] + F_bias: str # + F_dbias: str # + F_dropout: str # + F_mask: str # value from MASK_MAP + F_mode: str # value from MODE_MAP + F_deterministic: str # + mask_impl: str # + F_trload: str # @property def template(self) -> str: - return FMHA_BWD_KERNEL_HEADER + \ - FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = BWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bk1 = self.F_tile.F_bk1, - F_bk2 = self.F_tile.F_bk2, - F_bk3 = self.F_tile.F_bk3, - F_bk4 = self.F_tile.F_bk4, - F_bhdq = self.F_tile.F_bhdq, - F_bhdv = self.F_tile.F_bhdv, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_rm2 = self.F_tile.F_rm2, - F_rn2 = self.F_tile.F_rn2, - F_rk2 = self.F_tile.F_rk2, - F_wm0 = self.F_tile.F_wm0, 
- F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_dpad = self.F_dpad, - F_dvpad = self.F_dvpad, - F_bias = BIAS_MAP[self.F_bias], - F_dbias = BOOL_MAP[self.F_dbias], - F_dropout = DROPOUT_MAP[self.F_dropout], - F_occupancy = self.F_tile.F_occupancy, - F_mask = get_mask_map(self.mask_impl)[self.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_deterministic = BOOL_MAP[self.F_deterministic], - F_trload = BOOL_MAP[self.F_trload], - F_maxq = self.F_tile.max_seq_q - ) + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=BWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bk1=self.F_tile.F_bk1, + F_bk2=self.F_tile.F_bk2, + F_bk3=self.F_tile.F_bk3, + F_bk4=self.F_tile.F_bk4, + F_bhdq=self.F_tile.F_bhdq, + F_bhdv=self.F_tile.F_bhdv, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_rm2=self.F_tile.F_rm2, + F_rn2=self.F_tile.F_rn2, + F_rk2=self.F_tile.F_rk2, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_dpad=self.F_dpad, + F_dvpad=self.F_dvpad, + F_bias=BIAS_MAP[self.F_bias], + F_dbias=BOOL_MAP[self.F_dbias], + F_dropout=DROPOUT_MAP[self.F_dropout], + F_occupancy=self.F_tile.F_occupancy, + F_mask=get_mask_map(self.mask_impl)[self.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_deterministic=BOOL_MAP[self.F_deterministic], + F_trload=BOOL_MAP[self.F_trload], + F_maxq=self.F_tile.max_seq_q, + ) @property def name(self) -> str: def pad_name() -> str: - n = '' - if self.F_dpad : n += f'd{self.F_dpad}' - if self.F_dvpad : n += f'dv{self.F_dvpad}' - if n != '' : n = 'p' + n + n = "" + if self.F_dpad: + n += f"d{self.F_dpad}" + 
if self.F_dvpad: + n += f"dv{self.F_dvpad}" + if n != "": + n = "p" + n return n + pn = pad_name() n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' - - if self.F_dbias == 't' : n += '_dbias' - else: n += '_ndbias' - - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' + if pn != "": + n += f"_{pn}" else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' + n += "_npad" - if self.F_dropout != 'no' : n += f'_{self.F_dropout}' - else: n += '_ndropout' + if self.F_bias != "no": + n += f"_{self.F_bias}" + else: + n += "_nbias" - if self.F_deterministic == 't' : n += '_deterministic' - else: n += '_ndeterministic' + if self.F_dbias == "t": + n += "_dbias" + else: + n += "_ndbias" - if self.F_trload == 't' : n += '_trload' - else: n += '_ntrload' + if self.F_mask[0:2] == "s_": + if self.F_mask == "s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" + + if self.F_dropout != "no": + n += f"_{self.F_dropout}" + else: + n += "_ndropout" + + if self.F_deterministic == "t": + n += "_deterministic" + else: + n += "_ndeterministic" + + if self.F_trload == "t": + n += "_trload" + else: + n += "_ntrload" return n @property def filename(self) -> str: return self.name + ".cpp" + # TODO: design a more practical way to do it # this is current supported tile size. 
-def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: - if dtype == 'fp32' and tr_load == 'f': +def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: + if dtype == "fp32" and tr_load == "f": return [ # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, - FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), - ] - elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + ] # fmt: skip + elif (dtype == "fp16" or dtype == "bf16") and tr_load == "f": return [ - FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + 
FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - ] - elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't': + FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + ] # fmt: skip + elif (dtype == "fp16" or dtype == "bf16") and tr_load == "t": return [ - FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), - FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), - FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - - # FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32), - FmhaBwdDQDKDVTileSize( 32, 16, 64, 32, 64, 32, 16, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32), - # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16), - FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16), - ] + FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), + FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + # FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32), + FmhaBwdDQDKDVTileSize( 
32, 16, 64, 32, 64, 32, 16, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32), + # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16), + FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16), + ] # fmt: skip else: return [] -FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" + +FMHA_BWD_DOT_DO_O_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_bwd_dot_do_o_trait_{F_idx} = @@ -458,47 +493,55 @@ std::string fmha_bwd_dot_do_o_get_name_() }} """ + @dataclass(frozen=True) class FmhaBwdOGradDotOKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_spad : str # true/false - F_dvpad : str # - F_mode : str # value from MODE_MAP - F_occupancy : int + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_spad: str # true/false + F_dvpad: str # + F_mode: str # value from MODE_MAP + F_occupancy: int @property def template(self) -> str: - return FMHA_BWD_KERNEL_HEADER + \ - FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = BWD_DTYPE_MAP[self.F_dtype], - F_spad = BOOL_MAP[self.F_spad], - F_dvpad = BOOL_MAP[self.F_dvpad], - F_mode = MODE_MAP[self.F_mode], - F_occupancy = self.F_occupancy) + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=BWD_DTYPE_MAP[self.F_dtype], + F_spad=BOOL_MAP[self.F_spad], + F_dvpad=BOOL_MAP[self.F_dvpad], + F_mode=MODE_MAP[self.F_mode], + F_occupancy=self.F_occupancy, + ) @property def name(self) -> str: def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n + n = "" + if self.F_spad == "t": + n += "s" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = 
"p" + n return n + pn = pad_name() n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" - if pn != '' : n += f'_{pn}' - else: n += '_npad' + if pn != "": + n += f"_{pn}" + else: + n += "_npad" return n @property def filename(self) -> str: return self.name + ".cpp" -FMHA_BWD_CONVERT_DQ_KERNEL_BODY=""" + +FMHA_BWD_CONVERT_DQ_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_bwd_convert_dq_trait_{F_idx} = @@ -565,116 +608,133 @@ std::string fmha_bwd_convert_dq_get_name_() }} """ + @dataclass(frozen=True) class FmhaBwdConvertQGradKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_spad : str # true/false - F_dpad : str # - F_mode : str # value from MODE_MAP - F_occupancy : int # - F_deterministic : str # - disabled : bool # sometimes this kernel is not used + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_spad: str # true/false + F_dpad: str # + F_mode: str # value from MODE_MAP + F_occupancy: int # + F_deterministic: str # + disabled: bool # sometimes this kernel is not used @property def template(self) -> str: - return FMHA_BWD_KERNEL_HEADER + \ - FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = BWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_bm0, - F_bn0 = self.F_bn0, - F_spad = BOOL_MAP[self.F_spad], - F_dpad = BOOL_MAP[self.F_dpad], - F_mode = MODE_MAP[self.F_mode], - F_occupancy = self.F_occupancy, - F_deterministic = BOOL_MAP[self.F_deterministic]) + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=BWD_DTYPE_MAP[self.F_dtype], + 
F_bm0=self.F_bm0, + F_bn0=self.F_bn0, + F_spad=BOOL_MAP[self.F_spad], + F_dpad=BOOL_MAP[self.F_dpad], + F_mode=MODE_MAP[self.F_mode], + F_occupancy=self.F_occupancy, + F_deterministic=BOOL_MAP[self.F_deterministic], + ) @property def name(self) -> str: def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_dpad == 't' : n += 'd' - if n != '' : n = 'p' + n + n = "" + if self.F_spad == "t": + n += "s" + if self.F_dpad == "t": + n += "d" + if n != "": + n = "p" + n return n + pn = pad_name() n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}" - if pn != '' : n += f'_{pn}' - else: n += '_npad' - if self.F_deterministic == 't' : n += '_deterministic' - else: n += '_ndeterministic' + if pn != "": + n += f"_{pn}" + else: + n += "_npad" + if self.F_deterministic == "t": + n += "_deterministic" + else: + n += "_ndeterministic" return n @property def filename(self) -> str: return self.name + ".cpp" + @dataclass(frozen=True) class FmhaBwdApiTrait: - idx : int # this is not a tunable, but a counter to differentiate symbol + idx: int # this is not a tunable, but a counter to differentiate symbol # sync with fmha_bwd_traits<>, to generate fallback calls - hdim : int - dtype : str # data type - mode : str # value from MODE_MAP - tile : FmhaBwdDQDKDVTileSize - mask : str - bias : str - dbias : str - dropout : str - spad1d : str # spad for 1d kernels (dot/convert) - dpad : Literal[0, 1, 8] - dvpad : Literal[0, 1, 8] - deterministic : str - mask_impl : str - tr_load : str + hdim: int + dtype: str # data type + mode: str # value from MODE_MAP + tile: FmhaBwdDQDKDVTileSize + mask: str + bias: str + dbias: str + dropout: str + spad1d: str # spad for 1d kernels (dot/convert) + dpad: Literal[0, 1, 8] + dvpad: Literal[0, 1, 8] + deterministic: str + mask_impl: str + tr_load: str @property def bm0(self) -> int: return self.tile.F_bm0 + @property def bn0(self) -> int: return self.tile.F_bn0 + @property 
def bhdq(self) -> int: return self.tile.F_bhdq + @property def bhdv(self) -> int: return self.tile.F_bhdv @property def scheck(self) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.spad1d == 't': - return f'a.seqlen_q % {M0_1D} != 0' - else: # self.spad1d == 'f' - return f'a.seqlen_q % {M0_1D} == 0' + if self.mode == "group": + return "true" # always support + elif self.spad1d == "t": + return f"a.seqlen_q % {M0_1D} != 0" + else: # self.spad1d == 'f' + return f"a.seqlen_q % {M0_1D} == 0" @property def dcheck(self) -> str: - if self.dpad == 0: return f'a.hdim_q % {self.bhdq} == 0' - else: return f'a.hdim_q % {self.dpad} == 0' + if self.dpad == 0: + return f"a.hdim_q % {self.bhdq} == 0" + else: + return f"a.hdim_q % {self.dpad} == 0" @property def dvcheck(self) -> str: - if self.dvpad == 0: return f'a.hdim_v % {self.bhdv} == 0' - else: return f'a.hdim_v % {self.dvpad} == 0' + if self.dvpad == 0: + return f"a.hdim_v % {self.bhdv} == 0" + else: + return f"a.hdim_v % {self.dvpad} == 0" @property def extra_cond(self) -> str: - if self.tr_load == 't' and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128: + if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128: return "&& (a.seqlen_k <= 256)" else: return "" - + @property def convert_dq_bn0(self) -> int: - return self.tile.F_bn0 if self.deterministic == 't' else 0 + return self.tile.F_bn0 if self.deterministic == "t" else 0 @property def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel: @@ -683,15 +743,35 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 - F_dvpad = 't' if self.dvpad else 'f' - return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d, - F_dvpad=F_dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) + F_dvpad = "t" if self.dvpad else "f" + return FmhaBwdOGradDotOKernel( + F_idx=self.idx, + F_hdim=self.hdim, + F_dtype=self.dtype, + F_spad=self.spad1d, + 
F_dvpad=F_dvpad, + F_mode=self.mode, + F_occupancy=get_occupancy(self.dtype, self.hdim), + ) @property def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: - return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile, - F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout, - F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load) + return FmhaBwdDQDKDVKernel( + F_idx=self.idx, + F_hdim=self.hdim, + F_dtype=self.dtype, + F_tile=self.tile, + F_dpad=self.dpad, + F_dvpad=self.dvpad, + F_bias=self.bias, + F_dbias=self.dbias, + F_dropout=self.dropout, + F_mask=self.mask, + F_mode=self.mode, + F_deterministic=self.deterministic, + mask_impl=self.mask_impl, + F_trload=self.tr_load, + ) @property def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel: @@ -700,44 +780,76 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 - F_dpad = 't' if self.dpad else 'f' - return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, - F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad, - F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), - F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0) + F_dpad = "t" if self.dpad else "f" + return FmhaBwdConvertQGradKernel( + F_idx=self.idx, + F_hdim=self.hdim, + F_dtype=self.dtype, + F_bm0=M0_1D, + F_bn0=self.convert_dq_bn0, + F_spad=self.spad1d, + F_dpad=F_dpad, + F_mode=self.mode, + F_occupancy=get_occupancy(self.dtype, self.hdim), + F_deterministic=self.deterministic, + disabled=self.tile.max_seq_q != 0, + ) + class FmhaBwdApiPool: def __init__(self, mask_impl): - self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))) - + self.dq_dk_dv_pool = defaultdict( + lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + ) + self.mask_impl = 
mask_impl - def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None: + def register_dq_dk_dv_traits(self, trait: FmhaBwdApiTrait) -> None: # TODO: do we need to check duplication? - self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][trait.hdim].append(copy.copy(trait)) + self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][ + trait.hdim + ].append(copy.copy(trait)) @staticmethod def if_(i: int) -> str: - return 'if' if i == 0 else 'else if' + return "if" if i == 0 else "else if" def _api_innders(self, traits: List[FmhaBwdApiTrait]) -> str: inners = "" - i = 0 + i = 0 for trait in traits: - inners += FMHA_BWD_API_INNER_DISPATCH.format(F_if=self.if_(i), F_mode=MODE_MAP[trait.mode], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], - F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], - F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], - F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=trait.dpad, F_dvpad=trait.dvpad, - F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, - F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond, - F_convert_dq_bn0=trait.convert_dq_bn0) + inners += FMHA_BWD_API_INNER_DISPATCH.format( + F_if=self.if_(i), + F_mode=MODE_MAP[trait.mode], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], + F_dbias=BOOL_MAP[trait.dbias], + F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], + F_dropout=DROPOUT_MAP[trait.dropout], + F_scheck=trait.scheck, + F_dcheck=trait.dcheck, + 
F_dvcheck=trait.dvcheck, + F_hdim=trait.hdim, + F_dtype=BWD_DTYPE_MAP[trait.dtype], + F_spad1d=BOOL_MAP[trait.spad1d], + F_dpad=trait.dpad, + F_dvpad=trait.dvpad, + F_deterministic=BOOL_MAP[trait.deterministic], + F_trload=BOOL_MAP[trait.tr_load], + F_maxq=trait.tile.max_seq_q, + F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], + F_bn0=trait.tile.F_bn0, + F_cond_extra=trait.extra_cond, + F_convert_dq_bn0=trait.convert_dq_bn0, + ) i += 1 return inners @staticmethod def trload_sort_key(tf): - return 0 if tf == 't' else 1 # sort 't' before 'f' + return 0 if tf == "t" else 1 # sort 't' before 'f' @staticmethod def max_seq_q_sort_key(max_seq_q): @@ -746,9 +858,9 @@ class FmhaBwdApiPool: @staticmethod def max_seq_q_cond(max_seq_q: int) -> str: if max_seq_q == 0: - return 'true /* no seqlen_q limit */' + return "true /* no seqlen_q limit */" else: - return f'a.seqlen_q <= {max_seq_q}' + return f"a.seqlen_q <= {max_seq_q}" @staticmethod def dtype_cond(dtype: str) -> str: @@ -756,39 +868,56 @@ class FmhaBwdApiPool: @staticmethod def hdim_cond(hdim: int) -> str: - return f't.hdim_q <= {hdim} && t.hdim_v <= {hdim}' + return f"t.hdim_q <= {hdim} && t.hdim_v <= {hdim}" @property def api(self) -> str: - tr_load_cond_map = { - "t": "has_load_tr", - "f": "true /* no trload requirement */" - } - per_tr_load = '' + tr_load_cond_map = {"t": "has_load_tr", "f": "true /* no trload requirement */"} + per_tr_load = "" for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key): - per_max_seq_q = '' - for max_seq_q in sorted(self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key): - per_dtypes = '' + per_max_seq_q = "" + for max_seq_q in sorted( + self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key + ): + per_dtypes = "" for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]): - per_hdim_case = '' - for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]): + per_hdim_case = "" + for k, hdim in enumerate( + 
self.dq_dk_dv_pool[tr_load][max_seq_q][dtype] + ): traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim] inners = self._api_innders(traits) - per_hdim_case += FMHA_BWD_API_COND_STATEMENT(if_=k, F_cond=self.hdim_cond(hdim), F_body=inners) - per_dtypes += FMHA_BWD_API_COND_STATEMENT(if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case) - per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes) - per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4) + per_hdim_case += FMHA_BWD_API_COND_STATEMENT( + if_=k, F_cond=self.hdim_cond(hdim), F_body=inners + ) + per_dtypes += FMHA_BWD_API_COND_STATEMENT( + if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case + ) + per_max_seq_q += FMHA_BWD_API_COND_STATEMENT( + F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes + ) + per_tr_load += FMHA_BWD_API_COND_STATEMENT( + F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4 + ) if not per_tr_load: # empty string we add some ignore to suppress warning in api - per_tr_load += ' (void)t ; (void)s ; (void)a; (void)has_load_tr;' - result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load) - return result.replace('\n\n', '\n') + per_tr_load += " (void)t ; (void)s ; (void)a; (void)has_load_tr;" + result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch=per_tr_load) + return result.replace("\n\n", "\n") -def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]: - if filter_list == '': - filter_list = '*@*@*' - filters = filter_list.split('@') - filters.extend(['*'] * (3 - len(filters))) + +def get_bwd_blobs( + filter_list: str, receipt, mask_impl, optdim_list +) -> Tuple[ + FmhaBwdApiPool, + List[FmhaBwdOGradDotOKernel], + List[FmhaBwdDQDKDVKernel], + List[FmhaBwdConvertQGradKernel], +]: + if filter_list 
== "": + filter_list = "*@*@*" + filters = filter_list.split("@") + filters.extend(["*"] * (3 - len(filters))) filter_dot_do_o = filters[0] filter_convert_dq = filters[1] filter_dq_dk_dv = filters[2] @@ -803,30 +932,60 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load) dpad_options = itertools.product(*([[0, 8, 1]] * 2)) tf = ["t", "f"] - for tile, mode, mask, bias, dbias, dropout, spad1d, (dpad, dvpad), deterministic in itertools.product( - tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), tf, dpad_options, tf): - assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize" + for tile, mode, mask, bias, dbias, dropout, spad1d, ( + dpad, + dvpad, + ), deterministic in itertools.product( + tiles, + MODE_MAP.keys(), + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + tf, + DROPOUT_MAP.keys(), + tf, + dpad_options, + tf, + ): + assert isinstance(tile, FmhaBwdDQDKDVTileSize), ( + "tile must be FmhaBwdDQDKDVTileSize" + ) hdim = tile.F_bhdq if (mode == "group") and (spad1d == "f"): continue - if (mode == "group" or ('no' not in mask)) and tile.max_seq_q != 0: + if (mode == "group" or ("no" not in mask)) and tile.max_seq_q != 0: continue - if ((bias == "no" or bias == "alibi") and dbias == "t"): + if (bias == "no" or bias == "alibi") and dbias == "t": continue - if ("wg32" in dropout): + if "wg32" in dropout: continue if tr_load == "t": # tr_load can only work with 8 pad if dpad != dvpad or dpad == 1: continue - else: # tr_load == "f" + else: # tr_load == "f" # do not generate instance with only 1 of dpad/dvpad being 8 if dpad != dvpad and dpad == 8: continue if optdim_list != [-1]: if hdim not in optdim_list: continue - t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, 
mask_impl=mask_impl, tr_load=tr_load) + t = FmhaBwdApiTrait( + idx=0, + hdim=hdim, + dtype=dtype, + mode=mode, + tile=tile, + mask=mask, + bias=bias, + dbias=dbias, + dropout=dropout, + spad1d=spad1d, + dpad=dpad, + dvpad=dvpad, + deterministic=deterministic, + mask_impl=mask_impl, + tr_load=tr_load, + ) if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): continue @@ -837,69 +996,69 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm # Flash attention integration if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond = dtype in ["fp16", "bf16"] + cond &= bias in ["no", "alibi"] + cond &= dropout in ["no", "dropout_wg32", "dropout_wg16"] cond &= dpad == dvpad if not cond: continue elif receipt == 3: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] + cond = dtype in ["fp16", "bf16"] + cond &= bias in ["no", "alibi"] cond &= dpad == dvpad cond &= deterministic == "f" if not cond: continue # PyTorch integration elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'bias'] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond = dtype in ["fp16", "bf16"] + cond &= bias in ["no", "bias"] + cond &= dropout in ["no", "dropout_wg32", "dropout_wg16"] cond &= dpad == dvpad cond &= deterministic == "f" if not cond: continue # Aiter (mha_bwd) integration elif receipt == 300: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ["fp16", "bf16"] cond &= mode == "batch" - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dropout in ["no", "dropout_wg32", "dropout_wg16"] if not cond: continue # Aiter (mha_varlen_bwd) integration elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ["fp16", "bf16"] cond &= mode == "group" - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dropout in ["no", "dropout_wg32", "dropout_wg16"] if not 
cond: continue # aiter::mha_bwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ["fp16", "bf16"] if not cond: continue # fp32 only, all variations if receipt == 800: - cond = dtype == 'fp32' + cond = dtype == "fp32" cond &= dpad == dvpad if not cond: continue # fp32 only, minimal set of parameters elif receipt == 801: - cond = dtype == 'fp32' + cond = dtype == "fp32" cond &= hdim in [64, 128] cond &= dpad == dvpad - cond &= mode == 'batch' - cond &= bias == 'no' - cond &= dropout == 'no' - cond &= mask == 's_no' + cond &= mode == "batch" + cond &= bias == "no" + cond &= dropout == "no" + cond &= mask == "s_no" cond &= deterministic == "f" if not cond: continue else: # Don't build fp32 by default - if dtype == 'fp32': + if dtype == "fp32": continue gen_dot_do_o[t.dot_do_o_kernel] = True @@ -908,10 +1067,20 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm gen_convert_dq[t.convert_dq_kernel] = True api_pool.register_dq_dk_dv_traits(t) - return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys()) + return ( + api_pool, + list(gen_dot_do_o.keys()), + list(gen_dq_dk_dv.keys()), + list(gen_convert_dq.keys()), + ) -def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list) + +def write_blobs( + output_dir: Path, filter_list: str, receipt, optdim_list, mask_impl +) -> None: + api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs( + filter_list, receipt, mask_impl, optdim_list + ) update_file(output_dir / FMHA_BWD_API_FILENAME, api_pool.api) for k in kernels_dot_do_o: update_file(output_dir / k.filename, k.template) @@ -921,7 +1090,9 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask update_file(output_dir / k.filename, k.template) 
-def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None: +def list_blobs( + file_path: Path, filter_list: str, receipt, optdim_list, mask_impl +) -> None: _, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs( filter_list, receipt, mask_impl, optdim_list ) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index f898d5f7b2..919a7aa8c0 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -10,28 +10,25 @@ import os from pathlib import Path from typing import List, Optional, Tuple -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * +from codegen.cmake_config import GEN_DIR +from codegen.cpp_symbol_map import ( + LAYOUT_MAP, + BIAS_CHECK_MAP, + get_mask_check_map, + BOOL_MAP, + PIPELINE_MAP, + PIPELINE_ENUM_MAP, + MODE_MAP, + FWD_DTYPE_MAP, + BIAS_MAP, + get_mask_map, +) from codegen.utils import update_file -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} -K0_MAX_SUBMAX_MAP = { - 32 : 32, - 48 : 48, - 64 : 64, - 96 : 128, - 128: 128, - 192: 192, - 256: 256 -} +K0_MAX_SUBMAX_MAP = {32: 32, 48: 48, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n @@ -40,7 +37,7 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT #include "fmha_fwd.hpp" """ -FMHA_FWD_KERNEL_BODY=""" +FMHA_FWD_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; @@ -117,8 +114,8 @@ float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) }} """ -FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" -FMHA_FWD_API=""" +FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp" +FMHA_FWD_API = """ #include #include @@ -172,197 +169,254 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& }} """ -FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ +FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ {F_dtype_case} }} """ -FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} """ -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ {F_inner_dispatch} }} """ -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ using 
trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_fwd_(s, a); }} """ + @dataclass class CppConstraint: bool_expr: str = None def __str__(self): if self.bool_expr is None: - return 'true' + return "true" else: - return f'{self.bool_expr}' + return f"{self.bool_expr}" def __and__(self, other): - return CppConstraint(f'({str(self)}) && ({str(other)})') + return CppConstraint(f"({str(self)}) && ({str(other)})") + @dataclass class FmhaFwdApiTrait: - pipeline_tag : str + pipeline_tag: str # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - logits : str - mask : str - bias : str # - lse : str # - dropout : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - skip : str - tr_load : str - constraint : CppConstraint + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: str + mask: str + bias: str # + lse: str # + dropout: str + squant: str # + spad: str + skpad: str + dpad: str + dvpad: str + skip: str + tr_load: str + constraint: CppConstraint @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - 
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + ) @property def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag in ['qr_async', 'qr_async_trload']: - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr', 'qs']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False - - def seqtune(self, max_bm0 : int) -> str: - if self.bm0 == max_bm0: return 'true/*fall back to largest tile*/' + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag in ["qr_async", "qr_async_trload"]: + if self.spad == "t": + return "true" # always support + else: + return "true" + elif self.pipeline_tag in ["qr", "qs"]: + if self.spad == "t": + return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"a.seqlen_q % {self.bm0} == 0" else: - return f'a.seqlen_q <= {self.bm0}' + assert False + + def seqtune(self, max_bm0: int) -> str: + if self.bm0 == max_bm0: + return "true/*fall back to largest tile*/" + else: + return f"a.seqlen_q <= {self.bm0}" @property def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)' - else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' - elif self.pipeline_tag in ['qr', 'qs']: - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' - elif self.pipeline_tag == 'qr_async_trload': - if self.skpad == 't' : return 'true' - else: return 'true' - else: assert False + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag == "qr_async": + if self.skpad == "t": + return f"(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)" + else: + return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" + elif self.pipeline_tag in ["qr", "qs"]: + if self.skpad == "t": + return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" + elif self.pipeline_tag == "qr_async_trload": + if self.skpad == "t": + return "true" + else: + return "true" + else: + assert False @property def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag == "qr_async": vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {bk0submax} == 0' - else: assert False + if self.dpad == "t": + return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.hdim_q % {bk0submax} == 0" + else: + assert False @property def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag == "qr_async": vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {bk0submax} == 0' - else: assert False + if self.dvpad == "t": + return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"a.hdim_v % {bk0submax} == 0" + else: + assert False + @dataclass class FmhaFwdPipeline: - tag : str + tag: str - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_logits : str # t/f - F_bias : str # true/false - F_lse : str # - F_dropout : str # - F_squant : str # - F_mask : str # value from MASK_MAP - F_skip : str # true/false - F_trload : str # true/false - F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_bias: str # true/false + F_lse: str # + F_dropout: str # + F_squant: str # + F_mask: str # value from MASK_MAP + F_skip: str # true/false + F_trload: str # true/false + F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property def name(self) -> str: def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n return n + pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_logits == 't' : n += '_logits' - else: n += '_nlogits' - - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' - - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' + n += "_npad" - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" - if 
self.F_dropout == 't' : n += '_dropout' - else: n += '_ndropout' + if self.F_bias != "no": + n += f"_{self.F_bias}" + else: + n += "_nbias" - if self.F_skip == 't' : n += '_skip' - else: n += '_nskip' + if self.F_mask[0:2] == "s_": + if self.F_mask == "s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' + if self.F_lse == "t": + n += "_lse" + else: + n += "_nlse" - if self.F_trload == 't' : n += '_trload' - else: n += '_ntrload' + if self.F_dropout == "t": + n += "_dropout" + else: + n += "_ndropout" + + if self.F_skip == "t": + n += "_skip" + else: + n += "_nskip" + + if self.F_squant == "t": + n += "_squant" + else: + n += "_nsquant" + + if self.F_trload == "t": + n += "_trload" + else: + n += "_ntrload" return n + class FmhaFwdApiPool: def __init__(self, mask_impl): self.pool = dict() self.mask_impl = mask_impl - def register_traits(self, trait : FmhaFwdApiTrait) -> None: + def register_traits(self, trait: FmhaFwdApiTrait) -> None: # TODO: do we need to check duplication? 
if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() @@ -374,130 +428,171 @@ class FmhaFwdApiPool: @property def api(self) -> str: - tr_load_cond_map = { - "t": "has_load_tr", - "f": "true" - } + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} - per_tr_load =str() + per_tr_load = str() for tr_load in ["t", "f"]: - per_dtypes=str() + per_dtypes = str() for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() + per_hdim_case = str() for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] max_bm0 = max((t.bm0 for t in traits), default=0) - inners=str() + inners = str() for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, 
F_hdim_v=hdim_v, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes) + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_mode=MODE_MAP[trait.mode], + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_logits=BOOL_MAP[trait.logits], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], + F_dropout=BOOL_MAP[trait.dropout], + F_skip=BOOL_MAP[trait.skip], + F_trload=BOOL_MAP[trait.tr_load], + F_squant=BOOL_MAP[trait.squant], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune(max_bm0), + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) if not per_tr_load: # empty string we add some ignore to suppress warning in api - per_tr_load += ' (void)t ; 
(void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load) + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + @dataclass class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0 : int # number of warps for gemm0 along q seqlen - F_rn0 : int # number of warps for gemm0 along k seqlen - F_rk0 : int # number of warps for gemm0 along head dim q (not used) - F_rm1 : int # number of warps for gemm1 along q seqlen - F_rn1 : int # number of warps for gemm1 along head dim v - F_rk1 : int # number of warps for gemm1 along k seqlen (not used) - F_wm0 : int # gemm0 warp size along m - F_wn0 : int # gemm0 warp size along n - F_wk0 : int # gemm0 warp size along k - F_wm1 : int # gemm1 warp size along m - F_wn1 : int # gemm1 warp size along n - F_wk1 : int # gemm1 warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen 
+ F_rn1: int # number of warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ - f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ - f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + return ( + f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + @dataclass class FmhaFwdKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdPipeline - mask_impl : str + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdPipeline + mask_impl: str @property def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bm0 = 
self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0max = self.F_tile.F_bk0max, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_wm0 = self.F_tile.F_wm0, - F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits = BOOL_MAP[self.F_pipeline.F_logits], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_skip = BOOL_MAP[self.F_pipeline.F_skip], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_trload = BOOL_MAP[self.F_pipeline.F_trload]) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + 
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits=BOOL_MAP[self.F_pipeline.F_logits], + F_bias=BIAS_MAP[self.F_pipeline.F_bias], + F_lse=BOOL_MAP[self.F_pipeline.F_lse], + F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], + F_squant=BOOL_MAP[self.F_pipeline.F_squant], + F_skip=BOOL_MAP[self.F_pipeline.F_skip], + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_trload=BOOL_MAP[self.F_pipeline.F_trload], + ) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name + return ( + f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) @property def filename(self) -> str: @@ -505,75 +600,77 @@ class FmhaFwdKernel: def api_trait(self) -> FmhaFwdApiTrait: return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - logits=self.F_pipeline.F_logits, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - skip=self.F_pipeline.F_skip, - tr_load=self.F_pipeline.F_trload, - constraint=self.F_tile.F_constraint & 
self.F_pipeline.F_constraint) + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + class KernelComponentFactory: # TODO: design a more practical way to do it # this is current supported tile size per hdim @staticmethod - def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - if dtype == 'fp32': + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype == "fp32": return { # bm0, bn0, bk0, bn1, bk1, - ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - ( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), - FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - ( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), - FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 
32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - } - elif dtype == 'fp16' or dtype == 'bf16': + ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + } # fmt: skip + elif dtype == "fp16" or dtype == "bf16": return { - (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (160,160) : 
[FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - } - elif dtype == 'fp8' or dtype == 'fp8bf16': + ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } # fmt: skip + elif dtype == "fp8" or dtype == "fp8bf16": return { - (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 
32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - } - elif dtype == 'fp8fp32': + ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } # fmt: skip + elif dtype == "fp8fp32": return { - (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - } + (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } # fmt: skip else: return None @@ -586,95 +683,143 @@ class KernelComponentFactory: # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? pipelines = [] - if dtype in ['fp32']: - squant = 'f' - for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - elif dtype in ['fp16', 'bf16']: - squant = 'f' - for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + if dtype in ["fp32"]: + squant = "f" + for logits, mask, bias, lse, dropout, skip in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t", "f"], + ["t", "f"], + ["t", "f"], + ): 
+ pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + elif dtype in ["fp16", "bf16"]: + squant = "f" + for logits, mask, bias, lse, dropout, skip in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t", "f"], + ["t", "f"], + ["t", "f"], + ): if hdim == 256 and hdim_v == 256: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", 
"t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": - pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) - pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + if ( + (hdim, hdim_v) in [(64, 64), (128, 128)] + and logits == "f" + and bias == "no" + and dropout == "f" + and skip == "f" + ): + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip + if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']: + pipelines.append(FmhaFwdPipeline( "qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim + elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: # no need lse/dropout kernels - for logits, squant, mask, bias in itertools.product(["f"], 
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) - elif dtype in ['fp8fp16', 'bf8']: + for logits, squant, mask, bias in itertools.product( + ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() + ): + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip + elif dtype in ["fp8fp16", "bf8"]: # TODO None else: assert False return pipelines + class CustomFactory(KernelComponentFactory): @staticmethod - def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) - if dtype == 'fp16' or dtype == 'bf16': + if dtype == "fp16" or dtype == "bf16": if (128, 128) in result.keys(): - result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) + result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip return result -def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + +def get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() api_pool = FmhaFwdApiPool(mask_impl) - factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else 
KernelComponentFactory + factory = ( + CustomFactory + if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" + else KernelComponentFactory + ) for dtype in FWD_DTYPE_MAP.keys(): d = factory.get_hdim_tile_size_dict(dtype) - if d == None: + if d is None: continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for ((hdim, hdim_v), tiles), mode in itertools.product( + d.items(), MODE_MAP.keys() + ): for tile, next_tile in zip(tiles, tiles[1:]): - assert next_tile.F_bm0 >= tile.F_bm0, 'Tiles must be ordered by increasing bm0' - for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): + assert next_tile.F_bm0 >= tile.F_bm0, ( + "Tiles must be ordered by increasing bm0" + ) + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + ): if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue if (hdim, hdim_v) == (192, 128): # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': + if pipeline.F_bias != "no" or pipeline.F_dropout == "t": continue - if dtype != 'fp32': - if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + if dtype != "fp32": + if pipeline.tag != "qr_async_trload" and ( + ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) + or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 
128) + ): # non qr_async_trload only support km0=128 tile size when hdim is not 128 # non qr_async only support kn0=128 tile size when hdim is 128 continue - if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + if pipeline.tag == "qr_async_trload" and ( + ((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) + or ((hdim, hdim_v) not in [(64, 64), (128, 128)]) + ): continue # logits_soft_cap is only allowed if no bias - if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + if not ( + (pipeline.F_logits == "t" and pipeline.F_bias == "no") + or pipeline.F_logits == "f" + ): continue - k = FmhaFwdKernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': + k = FmhaFwdKernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue if optdim_list != [-1]: @@ -682,80 +827,80 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue # 2 - Flash attention integration if receipt in (2, 3): - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - cond &= pipeline.F_skip == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "alibi"] + cond &= pipeline.F_squant == "f" + cond &= pipeline.F_skip == "f" if not cond: continue # PyTorch integration elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'bias'] - cond &= pipeline.F_squant == 'f' - cond &= mode == 'batch' - cond &= pipeline.F_skip == 'f' - cond &= pipeline.F_logits == 'f' + cond = dtype 
in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "bias"] + cond &= pipeline.F_squant == "f" + cond &= mode == "batch" + cond &= pipeline.F_skip == "f" + cond &= pipeline.F_logits == "f" if not cond: continue # Aiter(mha_fwd) integration elif receipt == 100: - cond = dtype in ['fp16', 'bf16', 'fp8bf16'] - cond &= mode == 'batch' - cond &= pipeline.F_vlayout == 'row' - if dtype == 'fp8bf16': + cond = dtype in ["fp16", "bf16", "fp8bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + if dtype == "fp8bf16": cond &= hdim == 128 if not cond: continue # Aiter(mha_varlen_fwd) integration elif receipt == 200: - cond = dtype in ['fp16', 'bf16', 'fp8bf16'] - cond &= mode == 'group' - cond &= pipeline.F_vlayout == 'row' - if dtype == 'fp8bf16': + cond = dtype in ["fp16", "bf16", "fp8bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + if dtype == "fp8bf16": cond &= hdim == 128 if not cond: continue # aiter::mha_fwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16', 'fp8bf16'] - cond &= pipeline.F_vlayout == 'row' - if dtype == 'fp8bf16': + cond = dtype in ["fp16", "bf16", "fp8bf16"] + cond &= pipeline.F_vlayout == "row" + if dtype == "fp8bf16": cond &= hdim == 128 if not cond: continue elif receipt == 888: - cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32'] - cond &= pipeline.F_vlayout == 'row' + cond = dtype in ["fp8", "fp8bf16", "fp8fp32"] + cond &= pipeline.F_vlayout == "row" cond &= hdim == 128 if not cond: continue # fp32 only, all variations if receipt == 800: - cond = dtype == 'fp32' - cond &= pipeline.F_skip == 'f' - cond &= pipeline.F_logits == 'f' + cond = dtype == "fp32" + cond &= pipeline.F_skip == "f" + cond &= pipeline.F_logits == "f" if not cond: continue # fp32 only, minimal set of parameters elif receipt == 801: - cond = dtype == 'fp32' + cond = dtype == "fp32" cond &= hdim in [48, 128] - cond &= mode == 'batch' - cond &= pipeline.F_bias == 'no' - cond &= 
pipeline.F_lse == 'f' - cond &= pipeline.F_dropout == 'f' - cond &= pipeline.F_skip == 'f' - cond &= pipeline.F_logits == 'f' - cond &= pipeline.F_mask == 's_no' + cond &= mode == "batch" + cond &= pipeline.F_bias == "no" + cond &= pipeline.F_lse == "f" + cond &= pipeline.F_dropout == "f" + cond &= pipeline.F_skip == "f" + cond &= pipeline.F_logits == "f" + cond &= pipeline.F_mask == "s_no" if not cond: continue else: # Don't build fp32 by default - if dtype == 'fp32': + if dtype == "fp32": continue api_pool.register_traits(k.api_trait()) @@ -763,20 +908,28 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl return (api_pool, gen) + def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: update_file(autogen_dir / kernel.filename, kernel.template) -def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + +def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: - with file_path.open('a') as f: + +def list_blobs( + file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 38491b56c4..fcbf22fb18 
100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -5,23 +5,27 @@ import copy from dataclasses import dataclass import fnmatch -import itertools from pathlib import Path from typing import List, Optional, Tuple -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * +from codegen.cmake_config import GEN_DIR +from codegen.cpp_symbol_map import ( + FWD_DTYPE_MAP, + BOOL_MAP, + ROPE_MAP, + LAYOUT_MAP, + ROPE_CHECK_MAP, +) from codegen.ops.fmha_fwd import ( FmhaFwdApiTrait, - DTYPE_BITS, FMHA_FWD_KERNEL_HEADER, FMHA_FWD_API_PER_DTYPE, FMHA_FWD_API_PER_HDIM_CASE, ) -FMHA_FWD_APPENDKV_KERNEL_BODY=""" +FMHA_FWD_APPENDKV_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad}, @@ -66,8 +70,8 @@ float fmha_fwd_appendkv_(const ck_tile::stream_config& s, fmha_fw }} """ -FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp" -FMHA_FWD_APPENDKV_API=""" +FMHA_FWD_APPENDKV_API_FILENAME = "fmha_fwd_appendkv_api.cpp" +FMHA_FWD_APPENDKV_API = """ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} @@ -75,7 +79,7 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co }} """ -FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) && +FMHA_FWD_APPENDKV_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) && ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{ using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; @@ -83,81 +87,101 @@ FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == { }} """ + @dataclass class 
FmhaFwdAppendKVApiTrait: # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - bs : int # tile size along q seqlen - bsk : int # tile size along k seqlen - bd : int # tile size along qk gemm unroll - bdv : int # tile size along kv gemm unroll - vlayout : str - spad : str - skpad : str - dpad : str - dvpad : str - rope : str # key from ROPE_MAP - pagedkv : str + hdim: str + dtype: str # data type + bs: int # tile size along q seqlen + bsk: int # tile size along k seqlen + bd: int # tile size along qk gemm unroll + bdv: int # tile size along kv gemm unroll + vlayout: str + spad: str + skpad: str + dpad: str + dvpad: str + rope: str # key from ROPE_MAP + pagedkv: str @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\ - f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}' + return ( + f"{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-" + + f"{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}" + ) @property def scheck(self) -> str: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/' - else : return f'a.seqlen_q % {self.bs} == 0' + if self.spad == "t": + return f"true /*a.seqlen_q % {self.bs} != 0*/" + else: + return f"a.seqlen_q % {self.bs} == 0" @property def skcheck(self) -> str: # we do not check all the values in a.seqlen_k_ptr - return 'true' + return "true" @property def dcheck(self) -> str: - if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {self.bd} == 0' + if self.dpad == "t": + return f"true /*a.hdim_q % {self.bd} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"a.hdim_q % {self.bd} == 0" @property def dvcheck(self) -> str: - if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {self.bdv} == 0' + if self.dvpad == "t": + return f"true /*a.hdim_v % {self.bdv} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.hdim_v % {self.bdv} == 0" + @dataclass class FmhaFwdAppendKVPipeline: - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_rope : str # key from ROPE_MAP - F_pagedkv : str # t/f + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_rope: str # key from ROPE_MAP + F_pagedkv: str # t/f @property def name(self) -> str: def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n return n + pn = pad_name() - n = f'v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - if self.F_rope != 'no': n += f'_{self.F_rope}' - if self.F_pagedkv == 't': n += '_pagedkv' + n = f"v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" + if self.F_rope != "no": + n += f"_{self.F_rope}" + if self.F_pagedkv == "t": + n += "_pagedkv" return n + class FmhaFwdAppendKVApiPool: def __init__(self, mask_impl): self.pool = dict() self.mask_impl = mask_impl - def register_traits(self, trait : FmhaFwdApiTrait) -> None: + def register_traits(self, trait: FmhaFwdApiTrait) -> None: # TODO: do we need to check duplication? 
if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() @@ -168,74 +192,104 @@ class FmhaFwdAppendKVApiPool: @property def api(self) -> str: - per_dtypes=str() + per_dtypes = str() for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() + per_hdim_case = str() for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() + traits = self.pool[dtype][hdim] + inners = str() for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], - F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope], - F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_scheck=trait.scheck, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_rope_check=ROPE_CHECK_MAP[trait.rope], + F_pagedkv=BOOL_MAP[trait.pagedkv], + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_rope=ROPE_MAP[trait.rope], + F_bs=trait.bs, + F_bsk=trait.bsk, + F_bd=trait.bd, + F_bdv=trait.bdv, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) 
+ if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) if not per_dtypes: # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) + per_dtypes += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format( + F_dispatch=per_dtypes + ) + @dataclass class FmhaFwdAppendKVTileSize: - F_bs : int # tile size along q seqlen - F_bsk : int # tile size along k seqlen - F_bd : int # tile size along qk gemm unroll - F_bdv : int # tile size along kv gemm unroll - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_bs: int # tile size along q seqlen + F_bsk: int # tile size along k seqlen + F_bd: int # tile size along qk gemm unroll + F_bdv: int # tile size along kv gemm unroll + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property def name(self) -> str: - return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" + ( + "" if self.F_occupancy == -1 else f"_o{self.F_occupancy}" + ) + @dataclass class FmhaFwdAppendKVKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_tile : FmhaFwdAppendKVTileSize - F_pipeline : FmhaFwdAppendKVPipeline - mask_impl : str + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_tile: 
FmhaFwdAppendKVTileSize + F_pipeline: FmhaFwdAppendKVPipeline + mask_impl: str @property def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_APPENDKV_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bs = self.F_tile.F_bs, - F_bsk = self.F_tile.F_bsk, - F_bd = self.F_tile.F_bd, - F_bdv = self.F_tile.F_bdv, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_rope = ROPE_MAP[self.F_pipeline.F_rope], - F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], - F_occupancy = self.F_tile.F_occupancy) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bs=self.F_tile.F_bs, + F_bsk=self.F_tile.F_bsk, + F_bd=self.F_tile.F_bd, + F_bdv=self.F_tile.F_bdv, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + F_rope=ROPE_MAP[self.F_pipeline.F_rope], + F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], + F_occupancy=self.F_tile.F_occupancy, + ) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name + return ( + f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) @property def filename(self) -> str: @@ -243,40 +297,45 @@ class FmhaFwdAppendKVKernel: def api_trait(self) -> FmhaFwdAppendKVApiTrait: return FmhaFwdAppendKVApiTrait( - hdim=str(self.F_hdim), - dtype=self.F_dtype, - bs=self.F_tile.F_bs, - bsk=self.F_tile.F_bsk, - bd=self.F_tile.F_bd, - 
bdv=self.F_tile.F_bdv, - vlayout=self.F_pipeline.F_vlayout, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - rope=self.F_pipeline.F_rope, - pagedkv=self.F_pipeline.F_pagedkv) + hdim=str(self.F_hdim), + dtype=self.F_dtype, + bs=self.F_tile.F_bs, + bsk=self.F_tile.F_bsk, + bd=self.F_tile.F_bd, + bdv=self.F_tile.F_bdv, + vlayout=self.F_pipeline.F_vlayout, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + rope=self.F_pipeline.F_rope, + pagedkv=self.F_pipeline.F_pagedkv, + ) + # TODO: design a more practical way to do it # this is current supported tile size per hdim -def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': +def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": return { - '32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1), - '64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), - '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), - '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1), + "32": FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1), + "64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), + "128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), + "256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1), } - elif dtype == 'fp8' or dtype == 'bf8': + elif dtype == "fp8" or dtype == "bf8": return { - '64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), - '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), - '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1) + "64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), + "128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), + "256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1), } else: return None -def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, 
List[FmhaFwdAppendKVKernel]]: + +def get_fwd_appendkv_blobs( + kernel_filter: Optional[str], receipt, mask_impl, optdim_list +) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: @@ -284,25 +343,24 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' pipelines = [] - if dtype in ['fp16', 'bf16']: + if dtype in ["fp16", "bf16"]: # NOTICE: it will be very complicated if we consider all the hdim_q padding cases while # applying rotary embedding, so I just use 't' in inter/half pipelines - for vlayout in ['row', 'col']: + for vlayout in ["row", "col"]: for pagedkv in ["t", "f"]: - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv)) - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "f", "f", "no", pagedkv)) # fmt: skip + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "no", pagedkv)) # fmt: skip - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv)) - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "inter", pagedkv)) # fmt: skip + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "inter", pagedkv)) # fmt: skip - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv)) - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv)) - elif dtype in 
['fp8', 'bf8']: + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "half", pagedkv)) # fmt: skip + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "half", pagedkv)) # fmt: skip + elif dtype in ["fp8", "bf8"]: # rope/paged-kv is not supported - pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f')) - elif dtype in ['fp8fp16', 'fp8bf16']: + pipelines.append(FmhaFwdAppendKVPipeline("col", "t", "t", "t", "t", "no", "f")) # fmt: skip + elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None else: @@ -314,19 +372,21 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op for dtype in FWD_DTYPE_MAP.keys(): d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) - if d == None: + if d is None: continue for hdim_str in d.keys(): tile = d[hdim_str] hdim = int(hdim_str) for pipeline in get_pipelines(dtype, hdim): - k = FmhaFwdAppendKVKernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': + k = FmhaFwdAppendKVKernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue if optdim_list != [-1]: @@ -334,20 +394,20 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op continue # 2 - Flash attention integration if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" if not cond: continue # PyTorch integration elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" if not cond: continue # fp32 only if receipt == 800 or receipt == 801: - cond = dtype == 'fp32' + cond = dtype == "fp32" if not cond: continue @@ -356,21 +416,33 @@ def 
get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op return (api_pool, gen) + def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None: (autogen_dir / kernel.filename).write_text(kernel.template) -def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None: + +def write_fwd_appendkv_api(api_pool: FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: - api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) + +def write_blobs( + output_dir: Path, kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> None: + api_pool, kernels = get_fwd_appendkv_blobs( + kernel_filter, receipt, mask_impl, optdim_list + ) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_appendkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: - with file_path.open('a') as f: - _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) + +def list_blobs( + file_path: Path, kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: + _, kernels = get_fwd_appendkv_blobs( + kernel_filter, receipt, mask_impl, optdim_list + ) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 281357ef1e..31a35ecb97 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -9,41 +9,44 @@ import itertools from pathlib import 
Path from typing import List, Optional, Tuple, Union -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * +from codegen.cmake_config import GEN_DIR +from codegen.cpp_symbol_map import ( + PIPELINE_ENUM_MAP, + get_mask_check_map, + LAYOUT_MAP, + BIAS_CHECK_MAP, + MODE_MAP, + FWD_DTYPE_MAP, + BIAS_MAP, + get_mask_map, + BOOL_MAP, +) from codegen.ops.fmha_fwd import ( FmhaFwdTileSize, - FmhaFwdApiTrait, FMHA_FWD_KERNEL_HEADER, FMHA_FWD_API_PER_DTYPE, FMHA_FWD_API_PER_HDIM_CASE, ) -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} K0_MAX_SUBMAX_MAP = { - 32 : 32, - 64 : 64, - 96 : 128, + 32: 32, + 64: 64, + 96: 128, 128: 128, # 160: 160, - 256: 256 + 256: 256, } FMHA_FWD_SPLITKV_PIPELINE_MAP = { - "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", - "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", + "qr": "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", + "qr_nwarp_sshuffle": "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", } -FMHA_FWD_SPLITKV_KERNEL_BODY=""" +FMHA_FWD_SPLITKV_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; using fmha_mask_{F_idx} = {F_mask}; @@ -169,7 +172,7 @@ std::string fmha_fwd_splitkv_get_name_() }} """ -FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY=""" +FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; namespace {{ @@ -244,8 +247,8 @@ std::string fmha_fwd_splitkv_combine_get_name_() }} """ -FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp" -FMHA_FWD_SPLITKV_API=""" +FMHA_FWD_SPLITKV_API_FILENAME = "fmha_fwd_splitkv_api.cpp" +FMHA_FWD_SPLITKV_API = """ #include template @@ -270,7 +273,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const }} """ -FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" 
{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; @@ -298,172 +301,232 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F }} """ + @dataclass class FmhaFwdSplitKVApiTrait: - pipeline_tag : str + pipeline_tag: str # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - mask : str - logits : str - bias : str # - lse : str # - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - pagedkv : str + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + mask: str + logits: str + bias: str # + lse: str # + squant: str # + spad: str + skpad: str + dpad: str + 
dvpad: str + pagedkv: str @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ - f'{self.dvpad}-{self.pagedkv}' + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-" + + f"{self.dvpad}-{self.pagedkv}" + ) @property def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag == "qr_async": + if self.spad == "t": + return "true" # always support + else: + return "true" + elif self.pipeline_tag in ["qr", "qr_nwarp_sshuffle"]: + if self.spad == "t": + return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"a.seqlen_q % {self.bm0} == 0" + else: + assert False @property def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: - if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0' - else: assert False + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag == "qr_async": + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + else: + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + elif self.pipeline_tag in ["qr", "qr_nwarp_sshuffle"]: + if self.skpad == "t": + return f"true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0" + else: + assert False @property def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag == "qr_async": vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr", "qr_nwarp_sshuffle"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) - else : return f'a.hdim_q % {bk0submax} == 0' - else: assert False + if self.dpad == "t": + return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.hdim_q % {bk0submax} == 0" + else: + assert False @property def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag == "qr_async": vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr", "qr_nwarp_sshuffle"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {bk0submax} == 0' - else: assert False + if self.dvpad == "t": + return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"a.hdim_v % {bk0submax} == 0" + else: + assert False + @dataclass class FmhaFwdSplitKVPipeline: - tag : str + tag: str - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_logits : str # t/f - F_bias : str # true/false - F_lse : str # - F_squant : str # - F_pagedkv : str # t/f - F_mask : str # value from MASK_MAP + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_bias: str # true/false + F_lse: str # + F_squant: str # + F_pagedkv: str # t/f + F_mask: str # value from MASK_MAP @property def name(self) -> str: def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n return n + pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_logits == 't' : n += '_logits' - else: n += '_nlogits' - - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' - - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' + n += "_npad" - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' + if self.F_bias != "no": + n += f"_{self.F_bias}" + else: + n += "_nbias" - if self.F_pagedkv == 't' : n += '_pagedkv' - else: n += '_npagedkv' + if self.F_mask[0:2] == "s_": + if self.F_mask == "s_mask": + n 
+= "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" + + if self.F_lse == "t": + n += "_lse" + else: + n += "_nlse" + + if self.F_squant == "t": + n += "_squant" + else: + n += "_nsquant" + + if self.F_pagedkv == "t": + n += "_pagedkv" + else: + n += "_npagedkv" return n + @dataclass class FmhaFwdSplitKVCombinePipeline: - tag : str + tag: str - F_spad : str # true/false - F_dvpad : str # - F_lse : str # - F_squant : str # + F_spad: str # true/false + F_dvpad: str # + F_lse: str # + F_squant: str # @property def name(self) -> str: def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n + n = "" + if self.F_spad == "t": + n += "s" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n return n + pn = pad_name() - n = f'{self.tag}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' + n = f"{self.tag}" + if pn != "": + n += f"_{pn}" + else: + n += "_npad" - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' + if self.F_lse == "t": + n += "_lse" + else: + n += "_nlse" - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' + if self.F_squant == "t": + n += "_squant" + else: + n += "_nsquant" return n + class FmhaFwdSplitKVApiPool: def __init__(self, mask_impl): self.pool = dict() self.mask_impl = mask_impl - def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None: + def register_traits(self, trait: FmhaFwdSplitKVApiTrait) -> None: # TODO: do we need to check duplication? 
if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() @@ -474,97 +537,132 @@ class FmhaFwdSplitKVApiPool: @property def api(self) -> str: - per_dtypes=str() + per_dtypes = str() for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() + per_hdim_case = str() for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() + traits = self.pool[dtype][hdim] + inners = str() for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], - F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format( + F_if=if_k, + F_mode=MODE_MAP[trait.mode], + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_logits=BOOL_MAP[trait.logits], + 
F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], + F_squant=BOOL_MAP[trait.squant], + F_pagedkv=BOOL_MAP[trait.pagedkv], + F_scheck=trait.scheck, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) if not per_dtypes: # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) + per_dtypes += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format( + F_dispatch=per_dtypes + ) + @dataclass class FmhaFwdSplitKVCombineTileSize: - F_bn1 : int # tile size along v head_dim - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_bn1: int # tile size along v head_dim + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property def name(self) -> str: - return f"b{self.F_bn1}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + return f"b{self.F_bn1}" + ( + "" if self.F_occupancy == -1 else f"_o{self.F_occupancy}" + ) + @dataclass class FmhaFwdSplitKVKernel: - 
F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdSplitKVPipeline - mask_impl : str + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdSplitKVPipeline + mask_impl: str @property def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_SPLITKV_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0max = self.F_tile.F_bk0max, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_wm0 = self.F_tile.F_wm0, - F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits = BOOL_MAP[self.F_pipeline.F_logits], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag]) + return 
FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits=BOOL_MAP[self.F_pipeline.F_logits], + F_bias=BIAS_MAP[self.F_pipeline.F_bias], + F_lse=BOOL_MAP[self.F_pipeline.F_lse], + F_squant=BOOL_MAP[self.F_pipeline.F_squant], + F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag], + ) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_fwd_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name + return ( + f"fmha_fwd_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) @property def filename(self) -> str: @@ -572,103 +670,113 @@ class FmhaFwdSplitKVKernel: def api_trait(self) -> FmhaFwdSplitKVApiTrait: return FmhaFwdSplitKVApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - 
bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - logits=self.F_pipeline.F_logits, - mask=self.F_pipeline.F_mask, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - squant=self.F_pipeline.F_squant, - pagedkv=self.F_pipeline.F_pagedkv, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad) + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + logits=self.F_pipeline.F_logits, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + squant=self.F_pipeline.F_squant, + pagedkv=self.F_pipeline.F_pagedkv, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + ) + @dataclass class FmhaFwdSplitKVCombineKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdSplitKVCombineTileSize - F_pipeline : FmhaFwdSplitKVCombinePipeline + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdSplitKVCombineTileSize + F_pipeline: FmhaFwdSplitKVCombinePipeline @property def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bn1 = self.F_tile.F_bn1, - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_lse = 
BOOL_MAP[self.F_pipeline.F_lse], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_occupancy = self.F_tile.F_occupancy, - F_mode = MODE_MAP[self.F_mode]) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bn1=self.F_tile.F_bn1, + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + F_lse=BOOL_MAP[self.F_pipeline.F_lse], + F_squant=BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy=self.F_tile.F_occupancy, + F_mode=MODE_MAP[self.F_mode], + ) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_fwd_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name + return ( + f"fmha_fwd_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) @property def filename(self) -> str: return self.name + ".cpp" + # TODO: design a more practical way to do it # this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': +def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": return { - '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), - '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - # '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': + "32" : FmhaFwdTileSize( 32, 64, 
16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + } # fmt: skip + elif dtype == "fp8" or dtype == "bf8": return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - } + "64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } # fmt: skip else: return None -def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': + +def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": return { - '32' : FmhaFwdSplitKVCombineTileSize(32, -1), - '64' : FmhaFwdSplitKVCombineTileSize(32, -1), - '96' : FmhaFwdSplitKVCombineTileSize(32, -1), - '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + "32": FmhaFwdSplitKVCombineTileSize(32, -1), + "64": FmhaFwdSplitKVCombineTileSize(32, -1), + "96": FmhaFwdSplitKVCombineTileSize(32, -1), + "128": FmhaFwdSplitKVCombineTileSize(32, -1), # '160' : FmhaFwdSplitKVCombineTileSize(32, -1), - '256' : FmhaFwdSplitKVCombineTileSize(32, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': + "256": FmhaFwdSplitKVCombineTileSize(32, -1), + } + elif dtype == "fp8" or dtype == "bf8": return { - '64' : FmhaFwdSplitKVCombineTileSize(32, 
-1), - '128' : FmhaFwdSplitKVCombineTileSize(32, -1), - '256' : FmhaFwdSplitKVCombineTileSize(32, -1), + "64": FmhaFwdSplitKVCombineTileSize(32, -1), + "128": FmhaFwdSplitKVCombineTileSize(32, -1), + "256": FmhaFwdSplitKVCombineTileSize(32, -1), } else: return None -def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: + +def get_fwd_splitkv_blobs( + kernel_filter: Optional[str], receipt, mask_impl, optdim_list +) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: Pipeline = FmhaFwdSplitKVPipeline Kernel = FmhaFwdSplitKVKernel @@ -679,25 +787,29 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, opt # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' + squant = "t" if dtype == "fp8" else "f" pipelines = [] - if dtype in ['fp16', 'bf16']: - for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): - pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + if dtype in ["fp16", "bf16"]: + for logits, mask, bias, pagedkv in itertools.product( + ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"] + ): + pipelines.append(Pipeline( "qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline( "qr", "col", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', 
squant, pagedkv, mask)) + pipelines.append(Pipeline( "qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline( "qr", "col", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline( "qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline( "qr", "col", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) - elif dtype in ['fp8', 'bf8']: - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask)) - elif dtype in ['fp8fp16', 'fp8bf16']: + pipelines.append(Pipeline( "qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline( "qr", "col", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + elif dtype in ["fp8", "bf8"]: + for logits, mask, bias in itertools.product( + ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() + ): + pipelines.append(Pipeline( "qr", "col", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip + elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None else: @@ -709,28 +821,33 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, opt for dtype in FWD_DTYPE_MAP.keys(): d = get_fmha_fwd_tile_dict_from_dtype(dtype) - if d == None: + if d is None: continue - #for hdim_str, mode, mask, bias, 
lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): tile = d[hdim_str] hdim = int(hdim_str) for pipeline in get_pipelines(dtype, hdim): if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue # logits_soft_cap is only allowed if no bias - if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + if not ( + (pipeline.F_logits == "t" and pipeline.F_bias == "no") + or pipeline.F_logits == "f" + ): continue - k = Kernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': + k = Kernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue if optdim_list != [-1]: @@ -738,40 +855,40 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, opt continue # Flash attention integration if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "alibi"] + cond &= pipeline.F_squant == "f" if not cond: continue # PyTorch integration elif receipt == 4: - cond = dtype in ['fp16, bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'bias'] - cond &= pipeline.F_squant == 'f' - cond &= mode == 'batch' + cond = dtype in 
["fp16, bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "bias"] + cond &= pipeline.F_squant == "f" + cond &= mode == "batch" if not cond: continue # Aiter(mha_varlen_fwd) integration elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ["fp16", "bf16"] cond &= mode == "group" - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue # aiter::mha_fwd_splikv C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue # fp32 only if receipt == 800 or receipt == 801: - cond = dtype == 'fp32' + cond = dtype == "fp32" if not cond: continue @@ -780,7 +897,10 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, opt return (api_pool, gen) -def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim_list) -> List[FmhaFwdSplitKVCombineKernel]: + +def get_fwd_splitkv_combine_blobs( + kernel_filter: Optional[str], receipt, optdim_list +) -> List[FmhaFwdSplitKVCombineKernel]: Pipeline = FmhaFwdSplitKVCombinePipeline Kernel = FmhaFwdSplitKVCombineKernel @@ -791,14 +911,16 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? 
- squant = 't' if dtype == 'fp8' else 'f' + squant = "t" if dtype == "fp8" else "f" pipelines = [] - if dtype in ['fp16', 'bf16']: - for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]): - pipelines.append(Pipeline('unused', spad, dvpad, lse, squant)) - elif dtype in ['fp8', 'bf8']: + if dtype in ["fp16", "bf16"]: + for spad, dvpad, lse in itertools.product( + ["t", "f"], ["t", "f"], ["t", "f"] + ): + pipelines.append(Pipeline("unused", spad, dvpad, lse, squant)) + elif dtype in ["fp8", "bf8"]: # no need lse kernels - pipelines.append(Pipeline('unused', 'f', 'f', 'f', squant)) + pipelines.append(Pipeline("unused", "f", "f", "f", squant)) else: assert False return pipelines @@ -807,24 +929,26 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim for dtype in FWD_DTYPE_MAP.keys(): d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) - if d == None: + if d is None: continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): tile = d[hdim_str] hdim = int(hdim_str) for pipeline in get_pipelines(dtype, hdim): if mode == "group": - if pipeline.F_spad != 't': + if pipeline.F_spad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - k = Kernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline) - if kernel_filter != '': + k = Kernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + ) + if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue if optdim_list != [-1]: @@ -832,19 +956,19 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], 
receipt, optdim continue # Aiter(mha_varlen_fwd) integration if receipt == 200: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ["fp16", "bf16"] cond &= mode == "group" if not cond: continue # aiter::mha_fwd_splikv C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ["fp16", "bf16"] if not cond: continue # fp32 only if receipt == 800 or receipt == 801: - cond = dtype == 'fp32' + cond = dtype == "fp32" if not cond: continue @@ -852,34 +976,48 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim return gen -def write_single_kernel(kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None: + +def write_single_kernel( + kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path +) -> None: (autogen_dir / kernel.filename).write_text(kernel.template) -def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None: + +def write_fwd_splitkv_api(api_pool: FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None: file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME file_path.write_text(api_pool.api) -def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (2 - len(filter_list))) + +def write_blobs( + output_dir: Path, filter_list: str, receipt, optdim_list, mask_impl +) -> None: + filter_list = filter_list.split("@") + filter_list.extend([""] * (2 - len(filter_list))) kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) - api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) + api_pool, kernels = get_fwd_splitkv_blobs( + filter_list[1], receipt, mask_impl, optdim_list + ) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_splitkv_api(api_pool, output_dir) -def 
list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (2 - len(filter_list))) - with file_path.open('a') as f: +def list_blobs( + file_path: Path, filter_list: str, receipt, optdim_list, mask_impl +) -> None: + filter_list = filter_list.split("@") + filter_list.extend([""] * (2 - len(filter_list))) + + with file_path.open("a") as f: kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) + _, kernels = get_fwd_splitkv_blobs( + filter_list[1], receipt, mask_impl, optdim_list + ) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 3624b7b387..f22b0fa52f 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -9,28 +9,26 @@ import itertools from pathlib import Path from typing import List, Optional, Tuple -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * +from codegen.cmake_config import GEN_DIR +from codegen.cpp_symbol_map import ( + LAYOUT_MAP, + BIAS_CHECK_MAP, + get_mask_check_map, + MODE_MAP, + get_mask_map, + BIAS_MAP, + FWD_DTYPE_MAP, + BOOL_MAP, + PIPELINE_ENUM_MAP, +) -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} -K0_MAX_SUBMAX_MAP = { - 32 : 32, - 64 : 64, - 96 : 128, - 128: 128, - 256: 256 -} +K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} FMHA_FWD_PAGEDKV_PIPELINE_MAP = { - 
"qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS" + "qr_pagedkv": "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS" } FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT @@ -40,7 +38,7 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT #include "fmha_fwd.hpp" """ -FMHA_FWD_KERNEL_BODY=""" +FMHA_FWD_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; @@ -115,8 +113,8 @@ float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd }} """ -FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp" -FMHA_FWD_API=""" +FMHA_FWD_API_FILENAME = "fmha_fwd_pagedkv_api.cpp" +FMHA_FWD_API = """ float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} @@ -124,164 +122,215 @@ float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, con }} """ -FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} """ -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ {F_inner_dispatch} }} """ -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == 
{F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; return fmha_fwd_pagedkv_(s, a); }} """ + @dataclass class FmhaFwdApiTrait: - pipeline_tag : str + pipeline_tag: str # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - logits : str - mask : str - bias : str # - lse : str # - pagedkv : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - skip : str + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: str + mask: str + bias: str # + lse: str # + pagedkv: str + squant: str # + spad: str + skpad: str + dpad: str + dvpad: str + skip: str @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + 
f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + ) @property def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr_pagedkv', 'qs']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag == "qr_async": + if self.spad == "t": + return "true" # always support + else: + return "true" + elif self.pipeline_tag in ["qr_pagedkv", "qs"]: + if self.spad == "t": + return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.seqlen_q % {self.bm0} == 0" + else: + assert False @property def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr_pagedkv', 'qs']: - if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) - else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0' - else: assert False + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag == "qr_async": + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + else: + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + elif self.pipeline_tag in ["qr_pagedkv", "qs"]: + if self.skpad == "t": + return f"true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0" + else: + assert False @property def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag == "qr_async": vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr_pagedkv", "qs"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {bk0submax} == 0' - else: assert False + if self.dpad == "t": + return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"a.hdim_q % {bk0submax} == 0" + else: + assert False @property def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag == "qr_async": vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr_pagedkv", "qs"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {bk0submax} == 0' - else: assert False + if self.dvpad == "t": + return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.hdim_v % {bk0submax} == 0" + else: + assert False + @dataclass class FmhaFwdPipeline: - tag : str + tag: str - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_logits : str # t/f - F_bias : str # true/false - F_lse : str # - F_pagedkv : str # - F_squant : str # - F_mask : str # value from MASK_MAP - F_skip : str # true/false + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_bias: str # true/false + F_lse: str # + F_pagedkv: str # + F_squant: str # + F_mask: str # value from MASK_MAP + F_skip: str # true/false @property def name(self) -> str: def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n return n + pn = pad_name() - n = 
f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_logits == 't' : n += '_logits' - else: n += '_nlogits' - - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' - - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' + n += "_npad" - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" - if self.F_skip == 't' : n += '_skip' - else: n += '_nskip' + if self.F_bias != "no": + n += f"_{self.F_bias}" + else: + n += "_nbias" - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' + if self.F_mask[0:2] == "s_": + if self.F_mask == "s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" - if self.F_pagedkv == 't' : n += '_pagedkv' - else: n += '_npagedkv' + if self.F_lse == "t": + n += "_lse" + else: + n += "_nlse" + + if self.F_skip == "t": + n += "_skip" + else: + n += "_nskip" + + if self.F_squant == "t": + n += "_squant" + else: + n += "_nsquant" + + if self.F_pagedkv == "t": + n += "_pagedkv" + else: + n += "_npagedkv" return n + class FmhaFwdApiPool: def __init__(self, mask_impl): self.pool = dict() self.mask_impl = mask_impl - def register_traits(self, trait : FmhaFwdApiTrait) -> None: + def register_traits(self, trait: FmhaFwdApiTrait) -> None: # TODO: do we need to check duplication? 
if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() @@ -292,117 +341,152 @@ class FmhaFwdApiPool: @property def api(self) -> str: - per_dtypes=str() + per_dtypes = str() for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() + per_hdim_case = str() for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() + traits = self.pool[dtype][hdim] + inners = str() for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_mode=MODE_MAP[trait.mode], + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + 
F_logits=BOOL_MAP[trait.logits], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], + F_pagedkv=BOOL_MAP[trait.pagedkv], + F_skip=BOOL_MAP[trait.skip], + F_squant=BOOL_MAP[trait.squant], + F_scheck=trait.scheck, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) if not per_dtypes: # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + per_dtypes += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes) + @dataclass class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0 : int # number of warps for gemm0 along q seqlen - F_rn0 : int # number of warps for gemm0 along k seqlen - F_rk0 : int # number of warps for gemm0 along head dim q (not used) - 
F_rm1 : int # number of warps for gemm1 along q seqlen - F_rn1 : int # number of warps for gemm1 along head dim v - F_rk1 : int # number of warps for gemm1 along k seqlen (not used) - F_wm0 : int # gemm0 warp size along m - F_wn0 : int # gemm0 warp size along n - F_wk0 : int # gemm0 warp size along k - F_wm1 : int # gemm1 warp size along m - F_wn1 : int # gemm1 warp size along n - F_wk1 : int # gemm1 warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ - f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ - f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + return ( + 
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + @dataclass class FmhaFwdKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdPipeline - mask_impl : str + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdPipeline + mask_impl: str @property def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0max = self.F_tile.F_bk0max, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_wm0 = self.F_tile.F_wm0, - F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits = BOOL_MAP[self.F_pipeline.F_logits], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_pagedkv = 
BOOL_MAP[self.F_pipeline.F_pagedkv], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_skip = BOOL_MAP[self.F_pipeline.F_skip], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag]) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits=BOOL_MAP[self.F_pipeline.F_logits], + F_bias=BIAS_MAP[self.F_pipeline.F_bias], + F_lse=BOOL_MAP[self.F_pipeline.F_lse], + F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], + F_squant=BOOL_MAP[self.F_pipeline.F_squant], + F_skip=BOOL_MAP[self.F_pipeline.F_skip], + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag], + ) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name + return ( + 
f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) @property def filename(self) -> str: @@ -410,51 +494,56 @@ class FmhaFwdKernel: def api_trait(self) -> FmhaFwdApiTrait: return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - logits=self.F_pipeline.F_logits, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - pagedkv=self.F_pipeline.F_pagedkv, - squant=self.F_pipeline.F_squant, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - skip=self.F_pipeline.F_skip) + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + pagedkv=self.F_pipeline.F_pagedkv, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip, + ) + # TODO: design a more practical way to do it # this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': +def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": return { - # '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, 
-1), - # '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': + # "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + } # fmt: skip + elif dtype == "fp8" or dtype == "bf8": return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - } + "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } # fmt: skip else: return None -def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + +def 
get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: @@ -462,20 +551,27 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' + squant = "t" if dtype == "fp8" else "f" pipelines = [] - if dtype in ['fp16', 'bf16']: - for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]): - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) - elif dtype in ['fp8', 'bf8']: + if dtype in ["fp16", "bf16"]: + for logits, mask, bias, pagedkv, skip in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t"], + ["f"], + ): + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip + elif dtype in ["fp8", "bf8"]: # no need lse/dropout kernels - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 
't', squant, mask, 'f')) - elif dtype in ['fp8fp16', 'fp8bf16']: - # TODO - None + for logits, mask, bias in itertools.product( + ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() + ): + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip + elif dtype in ["fp8fp16", "fp8bf16"]: + pass # TODO else: assert False return pipelines @@ -485,9 +581,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl for dtype in FWD_DTYPE_MAP.keys(): d = get_fmha_fwd_tile_dict_from_dtype(dtype) - if d == None: + if d is None: continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): tile = d[hdim_str] hdim = int(hdim_str) @@ -495,24 +591,29 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl # if pipeline.F_pagedkv == 'f': # continue if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue if hdim == 192 and tile.F_bn1 == 128: # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_lse == 't' : + if pipeline.F_bias != "no" or pipeline.F_lse == "t": continue # logits_soft_cap is only allowed if no bias - if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + if not ( + (pipeline.F_logits == "t" and pipeline.F_bias == "no") + or pipeline.F_logits 
== "f" + ): continue - k = FmhaFwdKernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': + k = FmhaFwdKernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue if optdim_list != [-1]: @@ -520,49 +621,49 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue # 2 - Flash attention integration if receipt in (2, 3): - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - cond &= pipeline.F_skip == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "alibi"] + cond &= pipeline.F_squant == "f" + cond &= pipeline.F_skip == "f" if not cond: continue # PyTorch integration elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'bias'] - cond &= pipeline.F_squant == 'f' - cond &= pipeline.F_skip == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "bias"] + cond &= pipeline.F_squant == "f" + cond &= pipeline.F_skip == "f" if not cond: continue # Aiter(mha_fwd) integration elif receipt == 100: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'batch' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue # Aiter(mha_varlen_fwd) integration elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'group' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + 
cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue # aiter::mha_fwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue # fp32 only if receipt == 800 or receipt == 801: - cond = dtype == 'fp32' + cond = dtype == "fp32" if not cond: continue @@ -571,20 +672,28 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl return (api_pool, gen) + def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: (autogen_dir / kernel.filename).write_text(kernel.template) -def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + +def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: - with file_path.open('a') as f: + +def list_blobs( + file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 0317330511..fce37061f6 100644 --- a/example/ck_tile/01_fmha/generate.py 
+++ b/example/ck_tile/01_fmha/generate.py @@ -6,30 +6,45 @@ import argparse from enum import IntEnum from pathlib import Path import pkgutil -import sys from typing import List, Optional import codegen.ops -from codegen.cmake_config import * +from codegen.cmake_config import GEN_DIR class HandlerId(IntEnum): LIST_BLOBS = 0 WRITE_BLOBS = 1 + # inspect all modules under 'codegen.ops' and register API handlers ops = [] for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): - full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) + full_module_name = "%s.%s" % (codegen.ops.__name__, module_name) ops.append(importer.find_spec(module_name).loader.load_module(module_name)) -unwanted_prefix = 'fmha_' +unwanted_prefix = "fmha_" handlers = dict( - [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, - (op.list_blobs, op.write_blobs)) for op in ops] + [ + ( + op.__name__[len(unwanted_prefix) :] + if op.__name__.startswith(unwanted_prefix) + else op.__name__, + (op.list_blobs, op.write_blobs), + ) + for op in ops + ] ) assert 0 < len(handlers) -def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: + +def write_blobs( + output_dir: Optional[str], + api_list: List[str], + filters_list: List[str], + optdim_list: List[int], + receipt, + mask_impl, +) -> None: if output_dir is None: output_dir = Path(__file__).parent else: @@ -41,8 +56,16 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : handler = handlers[api][HandlerId.WRITE_BLOBS] handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl) + # list all the files that will be generated -def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: +def list_blobs( + output_file: Optional[str], + api_list: List[str], + filters_list: List[str], 
+ optdim_list: List[int], + receipt, + mask_impl, +) -> None: assert output_file is not None file_path = Path(output_file) @@ -53,6 +76,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : handler = handlers[api][HandlerId.LIST_BLOBS] handler(file_path, kernel_filter, receipt, optdim_list, mask_impl) + if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate", @@ -60,32 +84,29 @@ if __name__ == "__main__": ) parser.add_argument( "-d", - "--direction", # we keep 'direction' option for backward compatibility + "--direction", # we keep 'direction' option for backward compatibility "-a", "--api", - default='fwd', + default="fwd", required=False, - help="supply API(s) to generate (default: fwd). separated by comma." + help="supply API(s) to generate (default: fwd). separated by comma.", ) parser.add_argument( "-o", "--output_dir", required=False, - help="write all the blobs into a directory" + help="write all the blobs into a directory", ) parser.add_argument( - "-l", - "--list_blobs", - required=False, - help="list all the kernels to a file" + "-l", "--list_blobs", required=False, help="list all the kernels to a file" ) # TODO: if using filter, must apply same value to output_dir and list_blobs parser.add_argument( "-f", "--filter", - default='', + default="", required=False, - help="filter out kernels that need to generate, using fnmatch module" + help="filter out kernels that need to generate, using fnmatch module", ) parser.add_argument( @@ -93,7 +114,7 @@ if __name__ == "__main__": "--mask", default="simplified", required=False, - help="mask implementation, simplified/generic" + help="mask implementation, simplified/generic", ) parser.add_argument( @@ -101,32 +122,46 @@ if __name__ == "__main__": "--receipt", default=0, required=False, - help="codegen receipt. 
0: generate only 8xhdim coverage\n" + \ - " 1: generate more instance to cover all hdim\n" + \ - " 2: Only generate instance for Flash attention integration\n" + \ - " 4: Only generate instance for PyTorch integration\n" + \ - " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \ - " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \ - " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \ - " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \ - " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration" + help="codegen receipt. 0: generate only 8xhdim coverage\n" + + " 1: generate more instance to cover all hdim\n" + + " 2: Only generate instance for Flash attention integration\n" + + " 4: Only generate instance for PyTorch integration\n" + + " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + + " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + + " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration", ) parser.add_argument( "--optdim", - default='-1', + default="-1", required=False, - help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \ - "eg. --optdim=32,64,128,256" + help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + + "eg. 
--optdim=32,64,128,256", ) args = parser.parse_args() - api_list = args.direction.split(',') - filter_list = args.filter.split(',') - filter_list.extend([''] * (len(api_list) - len(filter_list))) - optdim_list = [int(hdim) for hdim in args.optdim.split(',')] + api_list = args.direction.split(",") + filter_list = args.filter.split(",") + filter_list.extend([""] * (len(api_list) - len(filter_list))) + optdim_list = [int(hdim) for hdim in args.optdim.split(",")] if args.list_blobs is not None: - list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) + list_blobs( + args.list_blobs, + api_list, + filter_list, + optdim_list, + int(args.receipt), + mask_impl=args.mask, + ) else: - write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) + write_blobs( + args.output_dir, + api_list, + filter_list, + optdim_list, + int(args.receipt), + mask_impl=args.mask, + ) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 5f589db8d0..c90948db55 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -6,47 +6,50 @@ import argparse from enum import IntEnum from pathlib import Path import sys -from typing import List, Optional, Any +from typing import List, Any import functools import itertools import copy from dataclasses import dataclass -def get_if_str(idx, total, lase_else = True): + +def get_if_str(idx, total, lase_else=True): if idx == 0: - return 'if' + return "if" elif idx < total - 1: - return 'else if' + return "else if" else: if lase_else: - return 'else' + return "else" else: - return 'else if' + return "else if" -XBIAS_ENUM_STR_MAP = [ - 'no', - 'xbias'] # pre-norm add bias + +XBIAS_ENUM_STR_MAP = ["no", "xbias"] # pre-norm add bias FUSED_ADD_ENUM_STR_MAP = [ - 'no', - 'pras', # pre-norm - 'pra' ] # post-norm + "no", + "pras", # pre-norm + "pra", +] # post-norm 
-FUSED_FUSED_SWEEP_STR_MAP = [ - 'no', - 'dquant' ] +FUSED_FUSED_SWEEP_STR_MAP = ["no", "dquant"] + +DATA_TYPE_MAP = { + "fp32": "float", + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", + "int8": "ck_tile::int8_t", + "fp8": "ck_tile::fp8_t", +} -DATA_TYPE_MAP = {'fp32' : 'float', - 'fp16' : 'ck_tile::fp16_t', - 'bf16' : 'ck_tile::bf16_t', - 'int8' : 'ck_tile::int8_t', - 'fp8' : 'ck_tile::fp8_t'} def BOOL_MAP(b_) -> str: if b_: - return 'true' + return "true" else: - return 'false' + return "false" + class layernorm_fwd_codegen: API_TRAITS_DEFINE = """ @@ -268,15 +271,15 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, """ - API_PER_DTYPE=""" {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{ + API_PER_DTYPE = """ {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{ {F_per_n_case} }} """ - API_PER_N_CASE=""" {F_if} {F_N_COND} {{ + API_PER_N_CASE = """ {F_if} {F_N_COND} {{ {F_inner_dispatch} }} """ - API_INNER_CASE=""" {F_if} {F_VEC_COND} + API_INNER_CASE = """ {F_if} {F_VEC_COND} r={F_instance_func}(s, a); """ @@ -313,138 +316,141 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, @dataclass class k_traits: - F_kPadN : bool - F_kSaveMeanInvStd : bool - F_kTwoPass : bool - F_kXbias : Any #: layernorm_fwd_codegen.k_bias_enum - F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum - F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum + F_kPadN: bool + F_kSaveMeanInvStd: bool + F_kTwoPass: bool + F_kXbias: Any #: layernorm_fwd_codegen.k_bias_enum + F_kFusedAdd: Any #: layernorm_fwd_codegen.k_fuesd_add_enum + F_kFusedQuant: Any #: layernorm_fwd_codegen.k_fused_sweep_enum @dataclass class k_shape: - F_BlockTile : List[int] - F_WarpPerBlock : List[int] - F_WarpTile : List[int] - F_Vector_ : List[int] + F_BlockTile: List[int] + F_WarpPerBlock: List[int] + F_WarpTile: List[int] + F_Vector_: List[int] + @property def F_BlockSize(self) -> int: - return functools.reduce(lambda a, b: a*b, self.F_WarpTile) + return 
functools.reduce(lambda a, b: a * b, self.F_WarpTile) @dataclass class k_problem: - F_XDataType : str - F_XBiasDataType : str - F_GammaDataType : str - F_BetaDataType : str - F_ComputeDataType : str - F_YDataType : str - F_MeanDataType : str - F_InvStdDataType : str - F_BlockShape : str - F_Traits : Any #k_traits + F_XDataType: str + F_XBiasDataType: str + F_GammaDataType: str + F_BetaDataType: str + F_ComputeDataType: str + F_YDataType: str + F_MeanDataType: str + F_InvStdDataType: str + F_BlockShape: str + F_Traits: Any # k_traits @dataclass class k_pipeline_one_pass: - F_Problem : Any #k_problem - + F_Problem: Any # k_problem + @dataclass class k_pipeline_two_pass: - F_Problem : Any #k_problem + F_Problem: Any # k_problem @dataclass class default_2d_epilogue_problem: - F_AccDataType : str - F_ODataType : str - F_kPadM : bool - F_kPadN : bool + F_AccDataType: str + F_ODataType: str + F_kPadM: bool + F_kPadN: bool @dataclass class default_2d_epilogue: - F_problem : Any + F_problem: Any @dataclass class k_kernel: - F_pipeline : Any - F_epilogue : Any + F_pipeline: Any + F_epilogue: Any @dataclass class h_traits: - F_XDataType : str - F_YDataType : str - F_SmoothScaleDataType : str - F_YScaleDataType : str - F_Repeat_M : int - F_Repeat_N : int - F_ThreadPerBlock_M : int - F_ThreadPerBlock_N : int - F_Vector_N : int - F_kPadN : bool - F_kSaveMeanInvStd_ : bool - F_kFastFDiv_ : bool - F_kWelford_ : bool - F_kTwoPass_ : bool - F_kXbias_ : int - F_kFusedAdd : int - F_kFusedQuant : int + F_XDataType: str + F_YDataType: str + F_SmoothScaleDataType: str + F_YScaleDataType: str + F_Repeat_M: int + F_Repeat_N: int + F_ThreadPerBlock_M: int + F_ThreadPerBlock_N: int + F_Vector_N: int + F_kPadN: bool + F_kSaveMeanInvStd_: bool + F_kFastFDiv_: bool + F_kWelford_: bool + F_kTwoPass_: bool + F_kXbias_: int + F_kFusedAdd: int + F_kFusedQuant: int @property - def trait_name(self) ->str: - t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, 
{DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' - t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}' - t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' + def trait_name(self) -> str: + t_ = f"{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}" + t_ += f", {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}" + t_ += f", {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}" return t_ # string when calling this kernel @property def call_name(self) -> str: - return f'layernorm2d_fwd_>' + return f"layernorm2d_fwd_>" # string when define this kernel @property def def_name(self) -> str: - return f'template float layernorm2d_fwd_>(const S&, A);' + return f"template float layernorm2d_fwd_>(const S&, A);" # this class hold kernel under same source file @dataclass class h_instance: - F_DataTypePair : str - F_N : str - F_xbias : int - F_add : int - F_sweep : int - instance_list : List[Any] # List[h_traits] + F_DataTypePair: str + F_N: str + F_xbias: int + F_add: int + F_sweep: int + instance_list: List[Any] # List[h_traits] @property def name(self) -> str: - prec_i, prec_o = self.F_DataTypePair.split(',') - dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}' - nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}' + prec_i, prec_o = self.F_DataTypePair.split(",") + dtype_str = f"{prec_i}" if prec_i == prec_o else 
f"{prec_i}_{prec_o}" + nnn = f"layernorm2d_fwd_{dtype_str}_n{self.F_N}" if self.F_xbias != 0: - nnn = nnn + '_' + XBIAS_ENUM_STR_MAP[self.F_xbias] + nnn = nnn + "_" + XBIAS_ENUM_STR_MAP[self.F_xbias] if self.F_add != 0: - nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] + nnn = nnn + "_" + FUSED_ADD_ENUM_STR_MAP[self.F_add] if self.F_sweep != 0: - nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] + nnn = nnn + "_" + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] return nnn @property - def instance_name(self) ->str: + def instance_name(self) -> str: return self.name @property - def content(self) ->str: - instance_defs = '' + def content(self) -> str: + instance_defs = "" for ins in self.instance_list: - instance_defs += ins.def_name + '\n' - return layernorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs) + instance_defs += ins.def_name + "\n" + return layernorm_fwd_codegen.INSTANCE_BASE.format( + F_instance_def=instance_defs + ) @property def name_api(self) -> str: - return 'layernorm2d_fwd_api' + return "layernorm2d_fwd_api" @property def name_common_header(self) -> str: - return 'layernorm2d_fwd_api_common' + return "layernorm2d_fwd_api_common" def content_api(self, args) -> str: # 1 sort based on dtype @@ -457,40 +463,64 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, t_dtype_dict[blob.F_DataTypePair][blob.F_N] = [] t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob) - d_str = '' + d_str = "" for i_d, dtype_ in enumerate(t_dtype_dict): blob_per_t = t_dtype_dict[dtype_] - n_str = '' + n_str = "" for i_n, n_ in enumerate(blob_per_t): blob_per_n = blob_per_t[n_] inner_str = "" for i_b, b_ in enumerate(blob_per_n): # generate single kernel instance file - #vec_str = "" + # vec_str = "" for i_ins, ins in enumerate(b_.instance_list): idx_in_n = i_b * len(b_.instance_list) + i_ins len_in_n = len(blob_per_n) * len(b_.instance_list) # _if = 'if' if i_ins == 0 else 'else if' if ins.F_kFusedQuant == 0: - _sweep_cond = 't.fused_quant == 
{f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) + _sweep_cond = "t.fused_quant == {f_fused_sweep}".format( + f_fused_sweep=ins.F_kFusedQuant + ) elif ins.F_kFusedQuant == 1: - _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( - f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType) + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == "{f_sx_type}" && t.prec_sy == "{f_sy_type}")'.format( + f_fused_sweep=ins.F_kFusedQuant, + f_sx_type=ins.F_SmoothScaleDataType, + f_sy_type=ins.F_YScaleDataType, + ) elif ins.F_kFusedQuant == 2: - _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( - f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) - _cond = '((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( - f_vec_n = ins.F_Vector_N, f_xbias = ins.F_kXbias, f_fused_add = ins.F_kFusedAdd, - f_sweep_cond = _sweep_cond) - inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), - F_VEC_COND = _cond, F_instance_func=ins.call_name) - #inner_str = inner_str + vec_str - n_cnd = f'(a.n <= {n_})' if isinstance(n_, int) else '' - n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) - prec_i, prec_o = dtype_.split(',') - d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == "{f_sy_type}")'.format( + f_fused_sweep=ins.F_kFusedQuant, + f_sy_type=ins.F_YScaleDataType, + ) + _cond = "((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))".format( + f_vec_n=ins.F_Vector_N, + f_xbias=ins.F_kXbias, + 
f_fused_add=ins.F_kFusedAdd, + f_sweep_cond=_sweep_cond, + ) + inner_str += self.API_INNER_CASE.format( + F_if=get_if_str(idx_in_n, len_in_n, False), + F_VEC_COND=_cond, + F_instance_func=ins.call_name, + ) + # inner_str = inner_str + vec_str + n_cnd = f"(a.n <= {n_})" if isinstance(n_, int) else "" + n_str += self.API_PER_N_CASE.format( + F_if=get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)), + F_N_COND=n_cnd, + F_inner_dispatch=inner_str, + ) + prec_i, prec_o = dtype_.split(",") + d_str += self.API_PER_DTYPE.format( + F_if=get_if_str(i_d, len(t_dtype_dict), False), + F_i_type=prec_i, + F_o_type=prec_o, + F_per_n_case=n_str, + ) - api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str) + api_base = self.API_BASE.format( + F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str + ) return api_base @property @@ -501,83 +531,982 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, h_traits = layernorm_fwd_codegen.h_traits h_instance = layernorm_fwd_codegen.h_instance - dynamic_quant_out_dtype = ['int8', 'fp8'] + dynamic_quant_out_dtype = ["int8", "fp8"] # some predefined support range # (prec_i,prec_o) for simplicity this string will be used as key for dict - scale_list = [('fp32,fp32')] - dtype_list = [('fp16,fp16'), ('bf16,bf16'), - ('fp16,int8'), ('bf16,int8'), - ('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 or fp8 out - types_8bit = ('int8', 'fp8') - types_16bit = ('int16', 'fp16', 'bf16') - #fused_add_list = [0, 1, 2] - #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant + scale_list = [("fp32,fp32")] + dtype_list = [ + ("fp16,fp16"), + ("bf16,bf16"), + ("fp16,int8"), + ("bf16,int8"), + ("fp16,fp8"), + ("bf16,fp8"), + ] # NOTE: only fused-dynamic-quant use int8 or fp8 out + types_8bit = ("int8", "fp8") + types_16bit = ("int16", "fp16", "bf16") + # fused_add_list = [0, 1, 2] + # fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant 
xbias_list = [0, 1] fused_add_list = [0, 1] - fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant + fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant # rm rn tm tn vn pd mv fdiv welford 2p xbias add sweep - h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], - '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], - '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], - '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], - '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 
'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 
'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, True, True, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, True, True, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]} + h_trait_dict = { + "64": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 1, + 8, + 8, + 8, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 1, + 4, + 16, + 4, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 1, + 4, + 64, + 1, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + ], + "128": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 1, + 4, + 16, + 8, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 1, + 4, + 64, + 2, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 2, + 4, + 64, + 1, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + ], + "256": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 1, + 4, + 64, + 4, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 2, + 4, + 64, + 2, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 4, + 4, + 64, + 1, + True, + 
False, + True, + True, + False, + 0, + 0, + 0, + ), + ], + "512": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 1, + 4, + 64, + 8, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 2, + 4, + 64, + 4, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 4, + 4, + 64, + 2, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 8, + 4, + 64, + 1, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + ], + "768": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 3, + 4, + 64, + 4, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 6, + 4, + 64, + 2, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 12, + 4, + 64, + 1, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + ], + "1024": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 1, + 2, + 128, + 8, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 2, + 2, + 128, + 4, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 4, + 2, + 128, + 2, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 4, + 1, + 256, + 1, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + ], + "1536": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 3, + 4, + 64, + 8, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 3, + 2, + 128, + 4, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 3, + 1, + 256, + 2, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 6, + 
1, + 256, + 1, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + ], + "2048": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 1, + 1, + 256, + 8, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 2, + 1, + 256, + 4, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 4, + 1, + 256, + 2, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 8, + 1, + 256, + 1, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + ], + "3072": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 3, + 1, + 128, + 8, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 3, + 1, + 256, + 4, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 6, + 1, + 256, + 2, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 3, + 1, + 1024, + 1, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + ], + "4096": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 2, + 1, + 256, + 8, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 4, + 1, + 256, + 4, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 2, + 1, + 1024, + 2, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 4, + 1, + 1024, + 1, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + ], + "6144": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 3, + 1, + 256, + 8, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 3, + 1, + 512, + 4, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + 
"x", + "y", + "xs", + "ys", + 1, + 3, + 1, + 1024, + 2, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 6, + 1, + 1024, + 1, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + ], + "8192": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 4, + 1, + 256, + 8, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 4, + 1, + 512, + 4, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 4, + 1, + 1024, + 2, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 8, + 1, + 1024, + 1, + True, + False, + True, + True, + False, + 0, + 0, + 0, + ), + ], + "big": [ + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 1, + 1, + 1024, + 8, + True, + False, + True, + True, + True, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 4, + 1, + 256, + 4, + True, + False, + True, + True, + True, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 12, + 1, + 256, + 2, + True, + False, + True, + True, + True, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + 1, + 4, + 1, + 1024, + 1, + True, + False, + True, + True, + True, + 0, + 0, + 0, + ), + ], + } total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N - for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list): - prec_i, prec_o = dtype.split(',') - scale_sm, scale_y = scale_type.split(',') + for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product( + dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list + ): + prec_i, prec_o = dtype.split(",") + scale_sm, scale_y = scale_type.split(",") if prec_o in dynamic_quant_out_dtype and 
fused_quant != 1: - continue # skip non dynamic quant case - if fused_quant == 1 and hs_key == 'big': + continue # skip non dynamic quant case + if fused_quant == 1 and hs_key == "big": continue current_hs = list() for chs_ in hs: - h_ = copy.copy(chs_) # copy the base instance out + h_ = copy.copy(chs_) # copy the base instance out h_.F_XDataType = prec_i h_.F_YDataType = prec_o h_.F_SmoothScaleDataType = scale_sm @@ -587,29 +1516,33 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, h_.F_kFusedQuant = fused_quant # disable welford update for 8bit and 16 bit smallN if not h_.F_kTwoPass_: - #disable 16 bit when set args disable_16b_welford + # disable 16 bit when set args disable_16b_welford if args.disable_16b_welford and prec_i in types_16bit: h_.F_kWelford_ = False - #disable 8bit by default + # disable 8bit by default elif prec_i in types_8bit or prec_o in types_8bit: h_.F_kWelford_ = False - #disable 16bit small N - elif prec_i in types_16bit and hs_key == '64': + # disable 16bit small N + elif prec_i in types_16bit and hs_key == "64": h_.F_kWelford_ = False - current_hs.append(h_) # + "\n" - #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ - current_n_str = 'big' if hs_key == 'big' else current_n - total_blob.append(h_instance(dtype, current_n_str, xbias, fused_add, fused_quant, current_hs)) + current_hs.append(h_) # + "\n" + # f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ + current_n_str = "big" if hs_key == "big" else current_n + total_blob.append( + h_instance( + dtype, current_n_str, xbias, fused_add, fused_quant, current_hs + ) + ) return total_blob def list_blobs(self, args) -> None: w_p = Path(self.working_path) - list_p = w_p / 'layernorm2d_fwd_blobs.txt' + list_p = w_p / "layernorm2d_fwd_blobs.txt" blobs = self.get_blobs(args) - with list_p.open('w') as list_f: + with list_p.open("w") as list_f: # api related file - list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") - list_f.write(str(w_p / (self.name_common_header + 
".hpp")) + "\n") + list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") + list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") # kernel instance file for b in blobs: list_f.write(str(w_p / (b.name + ".cpp")) + "\n") @@ -618,24 +1551,28 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, w_p = Path(self.working_path) w_str = self.content_api(args) (w_p / (self.name_api + ".cpp")).write_text(w_str) - (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) + (w_p / (self.name_common_header + ".hpp")).write_text( + self.content_common_header + ) blobs = self.get_blobs(args) for b in blobs: (w_p / (b.name + ".cpp")).write_text(b.content) + def list_blobs(args): - api_list = args.api.split(',') + api_list = args.api.split(",") for api in api_list: - if api == 'fwd': + if api == "fwd": layernorm_fwd_codegen(args.working_path, args.filter).list_blobs(args) def gen_blobs(args): - api_list = args.api.split(',') + api_list = args.api.split(",") for api in api_list: - if api == 'fwd': + if api == "fwd": layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs(args) + if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate", @@ -644,9 +1581,9 @@ if __name__ == "__main__": parser.add_argument( "-a", "--api", - default='fwd[all]', + default="fwd[all]", required=False, - help="supply API(s) to generate (default: fwd). separated by comma." + help="supply API(s) to generate (default: fwd). 
separated by comma.", ) # the directory for list_blobs/gen_blobs to write files into @@ -655,7 +1592,7 @@ if __name__ == "__main__": "--working_path", default="./", required=False, - help="the path where all the blobs are going to be generated" + help="the path where all the blobs are going to be generated", ) # this script have 2 modes @@ -667,15 +1604,15 @@ if __name__ == "__main__": parser.add_argument( "-l", "--list_blobs", - action='store_true', - help="list all the kernels to a file, " + action="store_true", + help="list all the kernels to a file, ", ) parser.add_argument( "-g", "--gen_blobs", - action='store_true', - help="generate all kernels into different tile" + action="store_true", + help="generate all kernels into different tile", ) # TODO: if using filter, must apply same value to output_dir and list_blobs @@ -683,7 +1620,7 @@ if __name__ == "__main__": "-f", "--filter", required=False, - help="filter out kernels that need to generate, using fnmatch module" + help="filter out kernels that need to generate, using fnmatch module", ) parser.add_argument( @@ -691,29 +1628,27 @@ if __name__ == "__main__": "--traits", default="all", required=False, - help="enable/disable some feature. default generate all" + help="enable/disable some feature. default generate all", ) parser.add_argument( - "-r", - "--receipt", - default=0, - required=False, - help="codegen receipt." + "-r", "--receipt", default=0, required=False, help="codegen receipt." 
) parser.add_argument( "--disable_16b_welford", default=False, required=False, - help="enable/disable welford for 16bit datatype n > 64" + help="enable/disable welford for 16bit datatype n > 64", ) args = parser.parse_args() # print(f'{args.list_blobs}-{args.gen_blobs}') - if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)): - print('gen_blobs/list_blobs must specify only one option') + if (args.gen_blobs and args.list_blobs) or ( + (not args.gen_blobs) and (not args.list_blobs) + ): + print("gen_blobs/list_blobs must specify only one option") sys.exit() p = Path(args.working_path) diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index 75d7abd0ad..88e58aba5f 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -6,45 +6,51 @@ import argparse from enum import IntEnum from pathlib import Path import sys -from typing import List, Optional, Any +from typing import List, Any import functools import itertools import copy from dataclasses import dataclass -def get_if_str(idx, total, lase_else = True): +def get_if_str(idx, total, lase_else=True): if idx == 0: - return 'if' + return "if" elif idx < total - 1: - return 'else if' + return "else if" else: if lase_else: - return 'else' + return "else" else: - return 'else if' + return "else if" + FUSED_ADD_ENUM_STR_MAP = [ - 'no', - 'pras', # pre-norm - 'pra' ] # post-norm + "no", + "pras", # pre-norm + "pra", +] # post-norm FUSED_FUSED_SWEEP_STR_MAP = [ - 'no', - 'sdquant', # smooth dynamic quant - 'dquant' ] # dynamic quant (without sm_scale) + "no", + "sdquant", # smooth dynamic quant + "dquant", +] # dynamic quant (without sm_scale) + +DATA_TYPE_MAP = { + "fp32": "float", + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", + "int8": "ck_tile::int8_t", + "fp8": "ck_tile::fp8_t", +} -DATA_TYPE_MAP = {'fp32' : 'float', - 'fp16' : 'ck_tile::fp16_t', - 'bf16' : 'ck_tile::bf16_t', - 
'int8' : 'ck_tile::int8_t', - 'fp8' : 'ck_tile::fp8_t'} def BOOL_MAP(b_) -> str: if b_: - return 'true' + return "true" else: - return 'false' + return "false" class rmsnorm_fwd_codegen: @@ -326,139 +332,142 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, @dataclass class k_traits: - F_kPadN : bool - F_kSaveMeanInvStd : bool - F_kTwoPass : bool - F_kFusedAdd : Any - F_kFusedQuant : Any + F_kPadN: bool + F_kSaveMeanInvStd: bool + F_kTwoPass: bool + F_kFusedAdd: Any + F_kFusedQuant: Any @dataclass class k_shape: - F_BlockTile : List[int] - F_WarpPerBlock : List[int] - F_WarpTile : List[int] - F_Vector_ : List[int] + F_BlockTile: List[int] + F_WarpPerBlock: List[int] + F_WarpTile: List[int] + F_Vector_: List[int] + @property def F_BlockSize(self) -> int: - return functools.reduce(lambda a, b: a*b, self.F_WarpTile) + return functools.reduce(lambda a, b: a * b, self.F_WarpTile) @dataclass class k_problem: - F_XDataType : str - F_GammaDataType : str - F_ComputeDataType : str - F_YDataType : str - F_InvRmsDataType : str - F_BlockShape : str - F_Traits : Any #k_traits + F_XDataType: str + F_GammaDataType: str + F_ComputeDataType: str + F_YDataType: str + F_InvRmsDataType: str + F_BlockShape: str + F_Traits: Any # k_traits @dataclass class k_pipeline_one_pass: - F_Problem : Any #k_problem + F_Problem: Any # k_problem @dataclass class k_pipeline_two_pass: - F_Problem : Any #k_problem + F_Problem: Any # k_problem @dataclass class default_2d_epilogue_problem: - F_AccDataType : str - F_ODataType : str - F_kPadM : bool - F_kPadN : bool + F_AccDataType: str + F_ODataType: str + F_kPadM: bool + F_kPadN: bool @dataclass class default_2d_epilogue: - F_problem : Any + F_problem: Any @dataclass class k_kernel: - F_pipeline : Any - F_epilogue : Any + F_pipeline: Any + F_epilogue: Any @dataclass class h_traits: - F_XDataType : str - F_YDataType : str - F_SmoothScaleDataType : str - F_YScaleDataType : str - F_UnquantYDataType : str - F_Repeat_M : int - F_Repeat_N : int - 
F_ThreadPerBlock_M : int - F_ThreadPerBlock_N : int - F_Vector_N : int - F_kPadN : bool - F_kSaveInvRms : bool + F_XDataType: str + F_YDataType: str + F_SmoothScaleDataType: str + F_YScaleDataType: str + F_UnquantYDataType: str + F_Repeat_M: int + F_Repeat_N: int + F_ThreadPerBlock_M: int + F_ThreadPerBlock_N: int + F_Vector_N: int + F_kPadN: bool + F_kSaveInvRms: bool F_kSaveUnquant: bool - F_kTwoPass : bool - F_kFusedAdd : int - F_kFusedQuant : int - F_use_model_sensitive_rmsnorm : int + F_kTwoPass: bool + F_kFusedAdd: int + F_kFusedQuant: int + F_use_model_sensitive_rmsnorm: int @property - def trait_name(self) ->str: - t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' - t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}' - t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}, {self.F_use_model_sensitive_rmsnorm:4}' + def trait_name(self) -> str: + t_ = f"{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}" + t_ += f", {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}" + t_ += f", {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}, {self.F_use_model_sensitive_rmsnorm:4}" return t_ # string when calling this kernel @property def call_name(self) -> str: - return f'rmsnorm2d_fwd_>' + return f"rmsnorm2d_fwd_>" # string when define this kernel @property def def_name(self) -> str: - 
return f'template float rmsnorm2d_fwd_>(const S&, A);' + return f"template float rmsnorm2d_fwd_>(const S&, A);" # this class hold kernel under same source file @dataclass class h_instance: - F_DataTypePair : str - F_N : str - F_add : int - F_sweep : int - F_saveunquant : bool - F_use_model_sensitive_rmsnorm : int - instance_list : List[Any] # List[h_traits] + F_DataTypePair: str + F_N: str + F_add: int + F_sweep: int + F_saveunquant: bool + F_use_model_sensitive_rmsnorm: int + instance_list: List[Any] # List[h_traits] @property def name(self) -> str: - prec_i, prec_o = self.F_DataTypePair.split(',') - dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}' - nnn = f'rmsnorm2d_fwd_{dtype_str}_n{self.F_N}' + prec_i, prec_o = self.F_DataTypePair.split(",") + dtype_str = f"{prec_i}" if prec_i == prec_o else f"{prec_i}_{prec_o}" + nnn = f"rmsnorm2d_fwd_{dtype_str}_n{self.F_N}" if self.F_add != 0: - nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] + nnn = nnn + "_" + FUSED_ADD_ENUM_STR_MAP[self.F_add] if self.F_sweep != 0: - nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] + nnn = nnn + "_" + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] if self.F_saveunquant: - nnn = nnn + '_saveunquant' + nnn = nnn + "_saveunquant" if self.F_use_model_sensitive_rmsnorm == 0: - nnn = nnn + '_nsm' + nnn = nnn + "_nsm" elif self.F_use_model_sensitive_rmsnorm == 1: - nnn = nnn + '_t5ml' + nnn = nnn + "_t5ml" return nnn @property - def instance_name(self) ->str: + def instance_name(self) -> str: return self.name @property - def content(self) ->str: - instance_defs = '' + def content(self) -> str: + instance_defs = "" for ins in self.instance_list: - instance_defs += ins.def_name + '\n' - return rmsnorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs) + instance_defs += ins.def_name + "\n" + return rmsnorm_fwd_codegen.INSTANCE_BASE.format( + F_instance_def=instance_defs + ) @property def name_api(self) -> str: - return 'rmsnorm2d_fwd_api' + return 
"rmsnorm2d_fwd_api" @property def name_common_header(self) -> str: - return 'rmsnorm2d_fwd_api_common' + return "rmsnorm2d_fwd_api_common" @property def content_api(self) -> str: @@ -472,40 +481,66 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, t_dtype_dict[blob.F_DataTypePair][blob.F_N] = [] t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob) - d_str = '' + d_str = "" for i_d, dtype_ in enumerate(t_dtype_dict): blob_per_t = t_dtype_dict[dtype_] - n_str = '' + n_str = "" for i_n, n_ in enumerate(blob_per_t): blob_per_n = blob_per_t[n_] inner_str = "" for i_b, b_ in enumerate(blob_per_n): # generate single kernel instance file - #vec_str = "" + # vec_str = "" for i_ins, ins in enumerate(b_.instance_list): idx_in_n = i_b * len(b_.instance_list) + i_ins len_in_n = len(blob_per_n) * len(b_.instance_list) # _if = 'if' if i_ins == 0 else 'else if' if ins.F_kFusedQuant == 0: - _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) + _sweep_cond = "t.fused_quant == {f_fused_sweep}".format( + f_fused_sweep=ins.F_kFusedQuant + ) elif ins.F_kFusedQuant == 1: - _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format( - f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant)) + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == "{f_sx_type}" && t.prec_sy == "{f_sy_type}" && t.save_unquant == {f_suq})'.format( + f_fused_sweep=ins.F_kFusedQuant, + f_sx_type=ins.F_SmoothScaleDataType, + f_sy_type=ins.F_YScaleDataType, + f_suq=BOOL_MAP(ins.F_kSaveUnquant), + ) elif ins.F_kFusedQuant == 2: - _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format( - f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant)) - _cond = '((a.n % {f_vec_n} == 0) && 
(t.fused_add == {f_fused_add}) && ({f_sweep_cond}) && (t.use_model_sensitive_rmsnorm == {f_use_model_sensitive_rmsnorm}) )'.format( - f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, - f_sweep_cond = _sweep_cond, f_use_model_sensitive_rmsnorm = ins.F_use_model_sensitive_rmsnorm) - inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), - F_VEC_COND = _cond, F_instance_func=ins.call_name) - #inner_str = inner_str + vec_str - n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else '' - n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) - prec_i, prec_o = dtype_.split(',') - d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == "{f_sy_type}" && t.save_unquant == {f_suq})'.format( + f_fused_sweep=ins.F_kFusedQuant, + f_sy_type=ins.F_YScaleDataType, + f_suq=BOOL_MAP(ins.F_kSaveUnquant), + ) + _cond = "((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}) && (t.use_model_sensitive_rmsnorm == {f_use_model_sensitive_rmsnorm}) )".format( + f_vec_n=ins.F_Vector_N, + f_fused_add=ins.F_kFusedAdd, + f_sweep_cond=_sweep_cond, + f_use_model_sensitive_rmsnorm=ins.F_use_model_sensitive_rmsnorm, + ) + inner_str += self.API_INNER_CASE.format( + F_if=get_if_str(idx_in_n, len_in_n, False), + F_VEC_COND=_cond, + F_instance_func=ins.call_name, + ) + # inner_str = inner_str + vec_str + n_cnd = f"(a.n <= {n_})" if (i_n < len(blob_per_t) - 1) else "" + n_str += self.API_PER_N_CASE.format( + F_if=get_if_str(i_n, len(blob_per_t)), + F_N_COND=n_cnd, + F_inner_dispatch=inner_str, + ) + prec_i, prec_o = dtype_.split(",") + d_str += self.API_PER_DTYPE.format( + F_if=get_if_str(i_d, len(t_dtype_dict), False), + F_i_type=prec_i, + F_o_type=prec_o, + F_per_n_case=n_str, + ) - api_base = 
self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str) + api_base = self.API_BASE.format( + F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str + ) return api_base @property @@ -516,150 +551,2081 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, h_traits = rmsnorm_fwd_codegen.h_traits h_instance = rmsnorm_fwd_codegen.h_instance - dynamic_quant_out_dtype = ['int8', 'fp8'] + dynamic_quant_out_dtype = ["int8", "fp8"] # some predefined support range # (prec_i,prec_o) for simplicity this string will be used as key for dict - scale_list = [('fp32,fp32')] - dtype_list = [('fp16,fp16'), ('bf16,bf16'), - ('fp16,int8'), ('bf16,int8'), - ('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 out - #fused_add_list = [0, 1, 2] - #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant + scale_list = [("fp32,fp32")] + dtype_list = [ + ("fp16,fp16"), + ("bf16,bf16"), + ("fp16,int8"), + ("bf16,int8"), + ("fp16,fp8"), + ("bf16,fp8"), + ] # NOTE: only fused-dynamic-quant use int8 out + # fused_add_list = [0, 1, 2] + # fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant fused_add_list = [0, 1] - fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant + fused_sweep_list = [ + 0, + 1, + 2, + ] # NOTE: only single pass can use fused (smooth) dynamic quant bool_list = [False, True] h_trait_dicts = { 0: { # rm rn tm tn vn pd mv unquant 2p add sweep srm - '64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 0)], - '128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 0), - h_traits('x', 
'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 0)], - '256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 0)], - '512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 0)], - '640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 0)], - '768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 0)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 0)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 
0, 0)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 0)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 0)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 0)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 0)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 
1,1024, 8, True, False, False, True, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 0)] + "64": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 8, + 8, + 8, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 4, + 16, + 4, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 4, + 64, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "128": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 4, + 16, + 8, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 4, + 64, + 2, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 4, + 64, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "256": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 4, + 64, + 4, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 4, + 64, + 2, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 4, + 64, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "512": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 4, + 64, + 8, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 4, + 64, + 4, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 4, + 64, + 2, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + 
"xs", + "ys", + "uqy", + 1, + 8, + 4, + 64, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "640": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 5, + 4, + 64, + 2, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 5, + 4, + 128, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "768": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 4, + 64, + 4, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 6, + 4, + 64, + 2, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 12, + 4, + 64, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "1024": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 2, + 64, + 8, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 2, + 64, + 4, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 8, + 2, + 64, + 2, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 256, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "1536": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 4, + 64, + 8, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 2, + 128, + 4, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 256, + 2, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 6, + 1, + 256, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "2048": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 1, + 256, + 8, + True, + 
False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 1, + 256, + 4, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 256, + 2, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 8, + 1, + 256, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "3072": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 128, + 8, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 256, + 4, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 6, + 1, + 256, + 2, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 1024, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "4096": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 1, + 256, + 8, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 256, + 4, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 1, + 1024, + 2, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 1024, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "6144": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 256, + 8, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 512, + 4, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 1024, + 2, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + 
"ys", + "uqy", + 1, + 6, + 1, + 1024, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "8192": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 256, + 8, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 512, + 4, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 1024, + 2, + True, + False, + False, + False, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 8, + 1, + 1024, + 1, + True, + False, + False, + False, + 0, + 0, + 0, + ), + ], + "big": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 1, + 1024, + 8, + True, + False, + False, + True, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 256, + 4, + True, + False, + False, + True, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 12, + 1, + 256, + 2, + True, + False, + False, + True, + 0, + 0, + 0, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 1024, + 1, + True, + False, + False, + True, + 0, + 0, + 0, + ), + ], }, 1: { # rm rn tm tn vn pd mv unquant 2p add sweep srm - '64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 1)], - '128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 1)], - '256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 32, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, 
False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 1)], - '512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 1)], - '640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 1)], - '768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 1)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 1)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 
1, 256, 1, True, False, False, False, 0, 0, 1)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 1)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 1)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 1)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 1)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 1)], - 'big' :[ 
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)] - } + "64": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 8, + 8, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 4, + 16, + 4, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 4, + 64, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + "128": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 4, + 16, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 4, + 64, + 2, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 4, + 64, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + "256": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 8, + 32, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 4, + 64, + 4, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 4, + 64, + 2, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 4, + 64, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + "512": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 4, + 64, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 4, + 64, + 4, + True, + False, + False, + False, + 
0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 4, + 64, + 2, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 8, + 4, + 64, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + "640": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 2, + 128, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 5, + 4, + 64, + 2, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 5, + 4, + 128, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + "768": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 2, + 128, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 4, + 64, + 4, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 6, + 4, + 64, + 2, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 12, + 4, + 64, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + "1024": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 2, + 128, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 2, + 64, + 4, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 8, + 2, + 64, + 2, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 256, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + "1536": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 1, + 256, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, 
+ 2, + 128, + 4, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 256, + 2, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 6, + 1, + 256, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + "2048": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 1, + 256, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 1, + 256, + 4, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 256, + 2, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 8, + 1, + 256, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + "3072": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 1, + 256, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 256, + 4, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 6, + 1, + 256, + 2, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 1024, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + "4096": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 1, + 256, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 256, + 4, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 2, + 1, + 1024, + 2, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 1024, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + 
"6144": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 256, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 512, + 4, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 3, + 1, + 1024, + 2, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 6, + 1, + 1024, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + "8192": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 256, + 8, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 512, + 4, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 1024, + 2, + True, + False, + False, + False, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 8, + 1, + 1024, + 1, + True, + False, + False, + False, + 0, + 0, + 1, + ), + ], + "big": [ + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 1, + 1, + 1024, + 8, + True, + False, + False, + True, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 256, + 4, + True, + False, + False, + True, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 12, + 1, + 256, + 2, + True, + False, + False, + True, + 0, + 0, + 1, + ), + h_traits( + "x", + "y", + "xs", + "ys", + "uqy", + 1, + 4, + 1, + 1024, + 1, + True, + False, + False, + True, + 0, + 0, + 1, + ), + ], + }, } total_blob = list() - for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive + for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive current_trait_dict = h_trait_dicts[model_sensitive_flag] for hs_key in current_trait_dict: hs = current_trait_dict[hs_key] current_n = hs_key - for dtype, scale_type, fused_add, 
fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list): - prec_i, prec_o = dtype.split(',') - scale_sm, scale_y = scale_type.split(',') - if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2: - continue # skip non dynamic quant case - if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big': + for ( + dtype, + scale_type, + fused_add, + fused_quant, + save_unquant, + ) in itertools.product( + dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list + ): + prec_i, prec_o = dtype.split(",") + scale_sm, scale_y = scale_type.split(",") + if ( + prec_o in dynamic_quant_out_dtype + and fused_quant != 1 + and fused_quant != 2 + ): + continue # skip non dynamic quant case + if (fused_quant == 1 or fused_quant == 2) and hs_key == "big": continue - if (fused_quant == 0 and save_unquant == True): - continue # save_unquant should always be false when there is no quant enabled + if fused_quant == 0 and save_unquant: + continue # save_unquant should always be false when there is no quant enabled current_hs = list() for chs_ in hs: - h_ = copy.copy(chs_) # copy the base instance out + h_ = copy.copy(chs_) # copy the base instance out h_.F_XDataType = prec_i h_.F_YDataType = prec_o h_.F_SmoothScaleDataType = scale_sm @@ -668,20 +2634,30 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, h_.F_kFusedAdd = fused_add h_.F_kFusedQuant = fused_quant h_.F_kSaveUnquant = save_unquant - current_hs.append(h_) # + "\n" - #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ - current_n_str = 'big' if hs_key == 'big' else current_n - total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, h_.F_use_model_sensitive_rmsnorm, current_hs)) + current_hs.append(h_) # + "\n" + # f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ + current_n_str = "big" if hs_key == "big" else current_n + total_blob.append( + h_instance( + dtype, + current_n_str, + 
fused_add, + fused_quant, + save_unquant, + h_.F_use_model_sensitive_rmsnorm, + current_hs, + ) + ) return total_blob def list_blobs(self) -> None: w_p = Path(self.working_path) - list_p = w_p / 'rmsnorm2d_fwd_blobs.txt' + list_p = w_p / "rmsnorm2d_fwd_blobs.txt" blobs = self.get_blobs() - with list_p.open('w') as list_f: + with list_p.open("w") as list_f: # api related file - list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") - list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") + list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") + list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") # kernel instance file for b in blobs: list_f.write(str(w_p / (b.name + ".cpp")) + "\n") @@ -689,23 +2665,25 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, def gen_blobs(self) -> None: w_p = Path(self.working_path) (w_p / (self.name_api + ".cpp")).write_text(self.content_api) - (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) + (w_p / (self.name_common_header + ".hpp")).write_text( + self.content_common_header + ) blobs = self.get_blobs() for b in blobs: (w_p / (b.name + ".cpp")).write_text(b.content) def list_blobs(args): - api_list = args.api.split(',') + api_list = args.api.split(",") for api in api_list: - if api == 'fwd': + if api == "fwd": rmsnorm_fwd_codegen(args.working_path, args.filter).list_blobs() def gen_blobs(args): - api_list = args.api.split(',') + api_list = args.api.split(",") for api in api_list: - if api == 'fwd': + if api == "fwd": rmsnorm_fwd_codegen(args.working_path, args.filter).gen_blobs() @@ -717,9 +2695,9 @@ if __name__ == "__main__": parser.add_argument( "-a", "--api", - default='fwd[all]', + default="fwd[all]", required=False, - help="supply API(s) to generate (default: fwd). separated by comma." + help="supply API(s) to generate (default: fwd). 
separated by comma.", ) # the directory for list_blobs/gen_blobs to write files into @@ -728,7 +2706,7 @@ if __name__ == "__main__": "--working_path", default="./", required=False, - help="the path where all the blobs are going to be generated" + help="the path where all the blobs are going to be generated", ) # this script have 2 modes @@ -740,15 +2718,15 @@ if __name__ == "__main__": parser.add_argument( "-l", "--list_blobs", - action='store_true', - help="list all the kernels to a file, " + action="store_true", + help="list all the kernels to a file, ", ) parser.add_argument( "-g", "--gen_blobs", - action='store_true', - help="generate all kernels into different tile" + action="store_true", + help="generate all kernels into different tile", ) # TODO: if using filter, must apply same value to output_dir and list_blobs @@ -756,7 +2734,7 @@ if __name__ == "__main__": "-f", "--filter", required=False, - help="filter out kernels that need to generate, using fnmatch module" + help="filter out kernels that need to generate, using fnmatch module", ) parser.add_argument( @@ -764,22 +2742,20 @@ if __name__ == "__main__": "--traits", default="all", required=False, - help="enable/disable some feature. default generate all" + help="enable/disable some feature. default generate all", ) parser.add_argument( - "-r", - "--receipt", - default=0, - required=False, - help="codegen receipt." + "-r", "--receipt", default=0, required=False, help="codegen receipt." 
) args = parser.parse_args() # print(f'{args.list_blobs}-{args.gen_blobs}') - if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)): - print('gen_blobs/list_blobs must specify only one option') + if (args.gen_blobs and args.list_blobs) or ( + (not args.gen_blobs) and (not args.list_blobs) + ): + print("gen_blobs/list_blobs must specify only one option") sys.exit() p = Path(args.working_path) diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index 64c9dda64a..3b4258d8b1 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -28,7 +28,8 @@ template + typename CDataType, + ck_tile::QuantType QuantMode> float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr) @@ -44,19 +45,20 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::RowColQuant; - using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; float ave_time{0}; diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index 93e461b9d3..bc271ac38e 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -11,12 +11,6 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 - -#ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 -#endif template constexpr ck_tile::index_t get_k_warp_tile() @@ -66,7 +60,6 @@ struct GemmConfigBase static constexpr auto Scheduler = 
ck_tile::GemmPipelineScheduler::Intrawave; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool Preshuffle = false; }; template @@ -102,15 +95,6 @@ struct PipelineTypeTraits using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; }; -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; -}; - using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; auto create_args(int argc, char* argv[]) @@ -119,7 +103,12 @@ auto create_args(int argc, char* argv[]) arg_parser.insert("Ms", "", "M dimensions - empty by default.") .insert("Ns", "", "N dimensions - empty by default.") .insert("Ks", "", "K dimensions - empty by default.") - .insert("stride_As", "", "Tensor A strides - it is empty by default.") + .insert( + "stride_As", + "", + "Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs + // can be set to zero if + // Ms/Ns/Ks is not empty .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") .insert("stride_Cs", "", "Tensor C strides - it is empty by default.") .insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.") @@ -132,7 +121,9 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") - .insert("kbatch", "1", "kbatch for SplitK"); + .insert("kbatch", "1", "kbatch for SplitK") + .insert("quant_mode", "tensor", "Choose tensor (default), or rowcol"); + ; bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -145,13 +136,17 @@ inline std::size_t get_workspace_size(const std::vector& gem template + typename CDataType, + 
ck_tile::QuantType QuantMode> float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk = false); + void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 10d317a2c7..19211ed494 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -43,6 +43,7 @@ template float invoke_gemm(int n_warmup, int n_repeat, @@ -102,9 +103,10 @@ float invoke_gemm(int n_warmup, BDataType, BQDataType, AccDataType, - CDataType>(stream, group_count, kargs_ptr); + CDataType, + QuantMode>(stream, group_count, kargs_ptr); - std::string op_name{"Grouped Gemm"}; + std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; std::size_t flop = 0, num_btype = 0; for(int j = 0; j < group_count; ++j) @@ -132,6 +134,7 @@ template (group_count)) && ...); }; const int group_count = arg_parser.get_int("group_count"); @@ -180,7 +183,8 @@ int run_grouped_gemm_example_with_layouts(int argc, ck_tile::index_t AQK, BQK; - if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs)) + if(!valid_input_data( + group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs)) { std::cout << "Please check the input data. Default values will be used." << std::endl; @@ -242,25 +246,49 @@ int run_grouped_gemm_example_with_layouts(int argc, const ck_tile::index_t M = Ms[i]; const ck_tile::index_t N = Ns[i]; const ck_tile::index_t K = Ks[i]; + if constexpr(QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant) + { + AQK = 1; // Row quantization: tensor shape [M, 1] or [1] + BQK = 1; // Column quantization: tensor shape [1, N] or [1] + } - AQK = 1; // Row quantization: tensor shape [M, 1]. 
Only for NT - BQK = N; // Column quantization: tensor shape [1, N]. Only for NT + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); + if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) + { + stride_AQs[i] = + ck_tile::get_default_stride(M, 1, stride_AQs[i], is_row_major(aq_layout)); + stride_BQs[i] = + ck_tile::get_default_stride(1, N, stride_BQs[i], is_row_major(bq_layout)); + } + else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) + { + stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] + stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] + } - stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); - stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); - stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); - stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout)); - stride_BQs[i] = ck_tile::get_default_stride(1, N, stride_BQs[i], is_row_major(bq_layout)); a_m_k_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); b_k_n_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout)))); c_m_n_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); - aq_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); - bq_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(1, N, stride_BQs[i], is_row_major(bq_layout)))); + if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) + { + 
aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); + } + else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) + { + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(1, 1, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout)))); + } std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc @@ -324,7 +352,8 @@ int run_grouped_gemm_example_with_layouts(int argc, AQLayout, BLayout, BQLayout, - CLayout>(warmup, repeat, group_count, gemm_descs); + CLayout, + QuantMode>(warmup, repeat, group_count, gemm_descs); for(int i = 0; i < group_count; i++) { @@ -339,13 +368,33 @@ int run_grouped_gemm_example_with_layouts(int argc, ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); - ck_tile::reference_gemm_rowcol_quant( - a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref); + if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) + { + ck_tile::reference_gemm_rowcol_quant(a_m_k_tensors[i], + aq_tensors[i], + b_k_n_tensors[i], + bq_tensors[i], + c_m_n_host_ref); + } + else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) + { + ck_tile::reference_gemm_tensor_quant(a_m_k_tensors[i], + aq_tensors[i], + b_k_n_tensors[i], + bq_tensors[i], + c_m_n_host_ref); + } + const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); const auto rtol_atol = @@ -367,7 +416,7 @@ int run_grouped_gemm_example_with_layouts(int argc, return pass; } -template +template int 
run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; @@ -388,7 +437,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a BDataType, BQDataType, CDataType, - AccDataType>( + AccDataType, + QuantMode>( argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); } else if(a_layout == "R" && b_layout == "R") @@ -399,8 +449,9 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a BDataType, BQDataType, CDataType, - AccDataType>( - argc, argv, Row{}, Row{}, Row{}, Row{}, Row{}); + AccDataType, + QuantMode>( + argc, argv, Row{}, Row{}, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { @@ -410,7 +461,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a BDataType, BQDataType, CDataType, - AccDataType>( + AccDataType, + QuantMode>( argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") @@ -421,7 +473,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a BDataType, BQDataType, CDataType, - AccDataType>( + AccDataType, + QuantMode>( argc, argv, Col{}, Col{}, Col{}, Col{}, Row{}); } else @@ -442,11 +495,28 @@ int run_grouped_gemm_example(int argc, char* argv[]) const std::string a_layout = arg_parser.get_str("a_layout"); const std::string b_layout = arg_parser.get_str("b_layout"); const std::string data_type = arg_parser.get_str("prec"); + std::string quant_mode = arg_parser.get_str("quant_mode"); if(data_type == "fp8") { - return run_gemm_example_prec_type, ck_tile::fp8_t>( - a_layout, b_layout, argc, argv); + if(quant_mode == "tensor") + { + return run_gemm_example_prec_type, + ck_tile::fp8_t, + ck_tile::QuantType::TensorQuant>( + a_layout, b_layout, argc, argv); + } + else if(quant_mode == "rowcol") + { + return run_gemm_example_prec_type, + ck_tile::fp8_t, + ck_tile::QuantType::RowColQuant>( + a_layout, 
b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported quantization mode!"); + } } else { diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index f822c7d8a7..dbdbe80c5d 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -143,7 +143,7 @@ int run_grouped_gemm_example_with_layouts(int argc, auto [result, arg_parser] = create_args(argc, argv); auto valid_input_data = [&](int group_count, const auto&... args) { - return !(args.empty() || ...) && group_count == (args.size() == ...); + return group_count != 0 && ((args.size() == static_cast(group_count)) && ...); }; const int group_count = arg_parser.get_int("group_count"); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc index db66d9a54b..1abb541e65 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc @@ -159,7 +159,7 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, using DsDataType = ck_tile::tuple; auto valid_input_data = [&](int group_count, const auto&... args) { - return !(args.empty() || ...) 
&& group_count == (args.size() == ...); + return group_count != 0 && ((args.size() == static_cast(group_count)) && ...); }; const int group_count = arg_parser.get_int("group_count"); diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 6d6b71ea18..1641549c98 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -1,6 +1,32 @@ -add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) +set(SUPPORTED_GPUS gfx908 gfx90a gfx942 gfx950) + +set(has_supported_gpu FALSE) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST SUPPORTED_GPUS) + set(has_supported_gpu TRUE) + break() + endif() +endforeach() + +if(has_supported_gpu) + add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) + add_executable(tile_example_mixed_prec_flatmm EXCLUDE_FROM_ALL mixed_prec/mixed_prec_flatmm.cpp) + add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp) + add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp) + add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp) + + set(EXAMPLE_FLATMM_COMPILE_OPTIONS) + set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS) + + if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + endif() + + target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + +endif() -set(EXAMPLE_FLATMM_COMPILE_OPTIONS) -# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) -# list(APPEND 
EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) -target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 3273fac674..9155b27dba 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -11,7 +11,102 @@ #include "ck_tile/host.hpp" #include "flatmm_basic.hpp" -#include "run_flatmm_example.inc" +#include + +template +constexpr const char* DataTypeToString() +{ + if constexpr(std::is_same_v) + { + return "fp16"; + } + else if constexpr(std::is_same_v) + { + return "fp8"; + } + else if constexpr(std::is_same_v) + { + return "bf8"; + } + else if constexpr(std::is_same_v) + { + return "bf16"; + } + else + { + return "unknown"; + } +} + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +// mfma_type, 0:32x32, 1:16x16 +template +auto shuffle_b(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + constexpr int MaxVecSize = 16 / sizeof(T); + constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile; + constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane); + + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, + FlatmmConfig::N_Warp_Tile, + k_ / ItemsPerAccess, + ItemsPerAccess}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 1, 3}); +} + +template +auto shuffle_b_v1(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + constexpr int MaxVecSize = 16 / sizeof(T); + constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile; + constexpr int ItemsPerAccess = 
std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane); + constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp; + + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Tile, + FlatmmConfig::N_Warp, + FlatmmConfig::N_Warp_Tile, + NRepeat, + k_ / ItemsPerAccess, + ItemsPerAccess}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5}); +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} template -float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s) +float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, + const ck_tile::stream_config& s) { using CodegenFlatmmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -80,14 +178,14 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c constexpr auto scheduler = FlatmmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using CodegenPipelineProblem = ck_tile::UniversalGemmPipelineProblem; + using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem; using CodegenFlatmmPipeline = ck_tile::FlatmmPipelineAGmemBGmemCRegV1; @@ -110,7 +208,10 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c 
FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, memory_operation, - FlatmmConfig::NumWaveGroups>>; + FlatmmConfig::NumWaveGroups, + false, + 1, + FlatmmConfig::TiledMMAPermuteN>>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. @@ -118,8 +219,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -167,40 +268,145 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - return ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - return ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } + return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } else { - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } }; - return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(RunSplitk, 
has_hot_loop, tail_num); + return ave_time; } +template +float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, + ck_tile::DeviceMem& b_shuffle_dev_buf, + ck_tile::DeviceMem& c_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + ScaleM scale_m, + ScaleN scale_n, + int n_warmup, + int n_repeat) +{ + ck_tile::ScaleFlatmmHostArgs args = {a_dev_buf.GetDeviceBuffer(), + b_shuffle_dev_buf.GetDeviceBuffer(), + {}, + c_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C, + scale_m, + scale_n}; + + float ave_time = flatmm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString() + << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A + << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time + << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + return ave_time; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "256", "m dimension") + .insert("n", "256", "n dimension") + .insert("k", "128", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. 
Validation on GPU") + .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") + .insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8") + .insert("persistent", "0", "0: no persistent, 1: persistent kernel") + .insert("warp_tile", + "0", + "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)"); + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +#include "run_flatmm_example.inc" + template