mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Support WMMA (gfx12) in FMHA (#2528)
* Pass hdim to tile_example_fmha_fwd in fp8 tests
* Add WMMA support to fwd FMHA pipelines
* Tune tile sizes a bit for less spilling
fp16 256 is still quite slow
* Fix Q grad tile distribution for warp size = 32 and hdim >= 256
With AccDataType = float and warp size = 32, K0 becomes 0; K repeat is required to correctly distribute the tile.
* Use code based on BlockDropout in BlockDropoutBwd
* Fix split KV combine kernel for gfx12 (warp size 32) and make it more universal
* Fix LSE LDS tensor descriptors: kMaxSplits and kM0 were swapped; it worked on gfx9
because they both equal 8, while on gfx12 they are 8 and 4;
* Fix Oacc LDS tensor descriptor: it was transposed even though its shape=[4 * kM0, kN1];
it worked on gfx9 because 4 * kM0 == kN1 == 32;
* Removing these hidden dependencies makes it possible to support:
* any number of warps (power-of-2), not only 4;
* kN1 = 16, not only 32;
* any number of splits;
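The two descriptor bugs above hinged on numeric coincidences; a small sketch (plain Python, with the concrete values from the commit message) makes them explicit:

```python
# LSE descriptor: kMaxSplits and kM0 were swapped in the LDS descriptor.
kMaxSplits, kM0_gfx9, kM0_gfx12 = 8, 8, 4
# On gfx9 the swap is invisible because both extents are equal:
assert (kMaxSplits, kM0_gfx9) == (kM0_gfx9, kMaxSplits)
# On gfx12 the swapped descriptor describes the wrong shape:
assert (kMaxSplits, kM0_gfx12) != (kM0_gfx12, kMaxSplits)

# Oacc descriptor: shape = [4 * kM0, kN1] was transposed; on gfx9
# 4 * kM0 == kN1 == 32, so the transposed (square) shape still matched.
kM0, kN1 = 8, 32
shape = (4 * kM0, kN1)
transposed = (kN1, 4 * kM0)
assert shape == transposed == (32, 32)  # square: transpose bug goes unnoticed
```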
* Rename ids like o_acc_4 and Oacc4 to eliminate confusion: kNumWarps doesn't have to be 4 now
* Replace hard-coded kN1 in dispatch code with the requested tile size
* Add gfx12-specific tile sizes for split KV
* Pass GPU architecture to kernel generation scripts
This is still a temporary solution.
* Build and run FMHA CI tests for gfx12
* Fix issue after merging
* Fix bwd tile sizes
The current pipelines always read only one K tile and one V tile; this
requires bk0 == bhdq and bk2 == bhdv (kK0 == kQKHeaddim and
kK2 == kVHeaddim).
* Use hardware f32->f8 on gfx12, remove v_perm
__builtin_amdgcn_perm is not needed because
__builtin_amdgcn_cvt_pk_fp8_f32 allows specifying which word (16 bits of a
32-bit dword) is used to store the results (two f8 values).
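A software model (not the GPU builtin itself) of what "selecting the destination word" means: two 8-bit values are packed into one 16-bit word of a 32-bit dword while the other word is preserved. The byte values below are arbitrary placeholders.

```python
def pack_fp8_pair(dword: int, lo_byte: int, hi_byte: int, word_sel: int) -> int:
    """Store two 8-bit values into word 0 or 1 of `dword`, keeping the other word."""
    assert word_sel in (0, 1)
    packed = (hi_byte << 8) | lo_byte      # two f8 values -> one 16-bit word
    shift = 16 * word_sel
    mask = 0xFFFF << shift
    return (dword & ~mask) | (packed << shift)

# Write into the high word, then into the low word of the same dword:
d = pack_fp8_pair(0x0000_0000, 0xAB, 0xCD, word_sel=1)
assert d == 0xCDAB_0000
assert pack_fp8_pair(d, 0x11, 0x22, word_sel=0) == 0xCDAB_2211
```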
* Update changelog
* Add WMMA support to pagedkv
* Fix scripts after rebasing
* Support 16x16 (MFMA, WMMA) and 32x32 (MFMA) tiles in fwd and bwd BlockDropout
Add comments with dropout implementation details
Fix performance regression of fwd+dropout
* Remove some usage of type punning (reinterpret_cast with ref or ptr) in Philox;
* "scalarize" seed and offset; they may come either from kernel args or from device memory
(presumably loaded with vector loads).
These changes help the compiler produce more optimal code and reduce register spilling.
Use WarpGemmDispatcher instead of explicit WarpGemmMfma... to get CWarpDstrEncoding
Use code based on BlockDropout in BlockDropoutBwd
Refactor BlockDropout (fwd)
Implement BlockDropout (fwd) for WMMA
Originally BlockDropout only supported 32x32 tiles (IsWG32 = true),
this version supports 16x16 tiles.
If MPerBlock > MWarp * 16, it can generate numbers for two 16x16 tiles, similarly
to BlockDropoutBwd.
Implement BlockDropoutBwd for WMMA
Remove MakeRandValLds* functions unused in BlockDropoutBwd
Remove unused Run overload from BlockDropoutBwd
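The tile bookkeeping described above can be sketched with hypothetical block/warp sizes: each warp owns MPerBlock / MWarp rows, and when that exceeds the 16-row WMMA tile it must generate random numbers for two 16x16 tiles.

```python
def tiles_per_warp(m_per_block: int, m_warp: int, tile_m: int = 16) -> int:
    """Number of tile_m-row tiles each warp covers along M (hypothetical sizes)."""
    rows_per_warp = m_per_block // m_warp
    assert rows_per_warp % tile_m == 0
    return rows_per_warp // tile_m

assert tiles_per_warp(m_per_block=64, m_warp=4) == 1   # one 16x16 tile per warp
assert tiles_per_warp(m_per_block=128, m_warp=4) == 2  # MPerBlock > MWarp * 16
```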
* Fix regression with philox seed and offset when they exceed 32-bit int
__builtin_amdgcn_readfirstlane works with 32-bit values, seed and offset
are 64-bit so they get truncated.
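Since `__builtin_amdgcn_readfirstlane` only moves 32 bits at a time, a 64-bit seed or offset has to be broadcast as two halves; a host-side model of the split and the truncation bug:

```python
def split64(x: int) -> tuple[int, int]:
    """Split a 64-bit value into (low, high) 32-bit halves."""
    return x & 0xFFFF_FFFF, (x >> 32) & 0xFFFF_FFFF

def join64(lo: int, hi: int) -> int:
    return (hi << 32) | lo

seed = 0x1234_5678_9ABC_DEF0
lo, hi = split64(seed)
assert join64(lo, hi) == seed          # round-trips through two 32-bit broadcasts
assert (seed & 0xFFFF_FFFF) != seed    # naive 32-bit truncation loses the high word
```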
* Fix names after cherry-picking
* Fix selection of a fallback tile based on bm0
The assumption that the largest bm0 == 128 is not always true for
current fp32 tiles.
* Do not use filters related to qr_async_trload
They disable tiles/pipelines which are valid for gfx12.
* Use different dstr encoding when C is transposed
* Do not call GetQKBlockGemm (and hence WarpGemmDispatcher) in host code
Some WarpGemmDispatcher instantiations are defined only
for specific archs and undefined on host.
Calculations related to sched barriers are moved from Pipeline's public
fields into pipeline's operator().
* Fix incorrect name WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
Correct name is WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution
because it's 32x32x16 with IterateK = 2 so K = 32, also all tiles used
in codegen scripts are 32, 32, 32.
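The naming arithmetic above is simply the effective K of the warp GEMM — the per-instruction K times the number of K iterations:

```python
def effective_k(k_per_inst: int, iterate_k: int) -> int:
    """Effective warp-GEMM K: a 32x32x16 instruction iterated twice covers K = 32."""
    return k_per_inst * iterate_k

assert effective_k(16, 2) == 32  # hence ...M32N32K32..., not ...K16...
```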
* Generalize usages of WarpGemmDispatcher for MFMA and WMMA
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution is still
used explicitly because of swizzle factor = 4.
* Mark has_load_tr as maybe_unused
There is no transposed loading for RDNA.
* Remove CK_TILE_USE_MFMA/WMMA from fmha-related code
* Detect BlockSize on host based on warp size of the current device
If kBlockSize == kNumWarps * get_warp_size(), the kernel would be launched with
kBlockSize / 2 because get_warp_size() always returns 64 on the host.
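A host-side sketch (hypothetical helper, values from the commit message) of why the block size must be derived from the device's warp size rather than the host constant of 64:

```python
def block_size(num_warps: int, device_warp_size: int) -> int:
    """Launch block size derived from the *device* warp size."""
    return num_warps * device_warp_size

HOST_WARP_SIZE = 64  # what get_warp_size() always returns on the host

# On warp-size-32 hardware, using the host constant halves the intended block:
assert block_size(4, HOST_WARP_SIZE) // 2 == block_size(4, 32)
# Querying the current device's warp size gives the intended launch size:
assert block_size(4, 32) == 128
```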
* Fix calculation of grid size for combine kernel with warp size = 32
* Add missing includes and header
* Support multiple archs in one binary for fwd
* Support multiple archs in one binary for fwd_splitkv, fwd_appendkv, pagedkv_prefill
* Support multiple archs in one binary for bwd
* trload kernels are compiled only for gfx950;
* instances with padding are checked after instances without padding so
they can be used as fallbacks (similarly to fwd);
* Extract common code from register_traits
* Revert "Fix regression with philox seed and offset when they exceed 32-bit int"
To simplify merging; the proper fix is already in develop.
* Support new numerical d paddings in trait ordering checks
* Build fp32 tests only on gfx9
* Do not use hardcoded M0 = 64 for dot bwd kernel
* Use textwrap.indent from the standard library
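The hand-rolled per-line indentation in the codegen scripts can be replaced by `textwrap.indent` from the standard library:

```python
import textwrap

# Indent a generated C++ body by one level; the snippet is illustrative only.
body = "r = fmha_bwd_(s, a);\nreturn r;"
assert textwrap.indent(body, "    ") == "    r = fmha_bwd_(s, a);\n    return r;"
```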
* Make fp8 pipelines on gfx12 consistent with gfx9
* Update tests for current pipelines
* Make ninja check more responsive in CI
ninja buffers its output, so this job appears to hang.
* Support fp8fp32 by limiting O vector size
The fp32 output type requires storing 8 * sizeof(float) = 32 bytes,
which is not implemented (here 8 is the number of C values per lane for
v_wmma_f32_16x16x16...).
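The arithmetic behind that limit: with 8 C values per lane, an fp32 output would need a single 32-byte vector store, so the O vector size is capped instead.

```python
C_VALUES_PER_LANE = 8                 # per lane for v_wmma_f32_16x16x16...
bytes_fp32 = C_VALUES_PER_LANE * 4    # sizeof(float)
bytes_fp16 = C_VALUES_PER_LANE * 2    # sizeof(_Float16)
assert bytes_fp32 == 32               # not implemented as one vector store
assert bytes_fp16 == 16               # fits the implemented store widths
```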
* Remove unused cmake options
* Unify including amd_buffer_addressing.hpp/_builtins.hpp
* Temporarily use amd_buffer_addressing.hpp on >=gfx10
amd_buffer_addressing_builtins.hpp uses inline asm for loads/stores
which is not compatible with >=gfx10:
* 1 scalar for exec masks instead of 2,
* gfx12 uses different instruction names etc.
* Update asm in bf16 conversions to work with warp 32
* Do not generate splitkv/appendkv with vlayout=col for consistency with fwd
* Add arch tags to kernels/host funcs, compile for each arch separately
* Add kM0 to fmha_bwd_dot_do_o kernel name to match filename
* Add workaround for miscompilation of bwd with padded hdim
SWDEV-559729: v_wmma instructions can be incorrectly placed in divergent
branches used to store padded tensors (when some lanes are inactive due
to padding). Inline asm with dummy dependencies on the tensors' VGPRs
prevents the compiler from doing this.
* Fix add_gtest_executable for absolute paths
Some tests (like gemm_tile_engine) pass absolute paths to source files.
In CI the branch name is a part of the root dir, and if the branch name
contains "wmma", "xdl" etc., files can be incorrectly excluded.
* Run only hdim 128 smoke tests for fp8fp32
There are no instances for hdim 64 and 256.
* Format py with ruff to simplify merging develop
* Fix incorrect var name
* Codegen for gfx9,gfx950 when --targets is not specified
Aiter and PyTorch require changes for passing their targets to the codegen scripts.
With this temporary solution the files are generated but not all of them
have to be really built (depending on the used --offload-arch=).
* Combine arch-related values into ArchTrait
This more centralized approach removes duplication of various formatting templates.
* Try a workaround for Jenkins error "groovyjarjarasm.asm.MethodTooLargeException: Method too large"
Some code is extracted into a function.
[ROCm/composable_kernel commit: 1e77695fe8]
@@ -19,12 +19,15 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
 * Added support for f32 to FMHA (fwd/bwd).
 * Added tensor-wise quantization for CK_TILE GEMM.
 * Added support for batched contraction kernel.
+* Added WMMA (gfx12) support for FMHA.
 * Added pooling kernel in CK_TILE
 * Added top-k sigmoid kernel in CK_TILE

 ### Changed

 * Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594)
+* Added an optional template parameter `Arch` (`gfx9_t`, `gfx12_t` etc.) to `make_kernel` to support linking multiple object files that have the same kernel compiled for different architectures.
+* FMHA examples and tests can be built for multiple architectures (gfx9, gfx950, gfx12) at the same time.

 ## Composable Kernel 1.1.0 for ROCm 7.1.0
@@ -346,7 +346,6 @@ endif()
 option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
-option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF)
 option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF)
 option(ENABLE_JSON_DUMP "Whether to enable json dump for examples." OFF)

@@ -617,11 +616,11 @@ endif()
 if(NOT MIOPEN_REQ_LIBS_ONLY)
     # make check runs the entire set of examples and tests
-    add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
+    add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} USES_TERMINAL)
     # make smoke runs the tests and examples that runs within 30 seconds on gfx90a
-    add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST")
+    add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST" USES_TERMINAL)
     # make regression runs the tests and examples that runs for more 30 seconds on gfx90a
-    add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST")
+    add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST" USES_TERMINAL)
 endif()
76
Jenkinsfile
vendored
@@ -42,6 +42,34 @@ def checkForPattern(pattern, log) {
     return [found: false, matchedLine: "", context: ""]
 }

+// Scan build logs and send notifications
+def sendFailureNotifications() {
+    // Get the build log.
+    def buildLog = sh(script: 'wget -q --no-check-certificate -O - ' + BUILD_URL + 'consoleText', returnStdout: true)
+    // Check for patterns in the log.
+    def foundPatterns = []
+    for (patternMap in failurePatterns) {
+        def result = checkForPattern(patternMap.pattern, buildLog)
+        if (result.found) {
+            foundPatterns.add([
+                description: patternMap.description,
+                matchedLine: result.matchedLine,
+                context: result.context
+            ])
+        }
+    }
+    // Send a notification for each matched failure pattern.
+    for (patternMap in foundPatterns) {
+        withCredentials([string(credentialsId: 'ck_ci_errors_webhook_url', variable: 'WEBHOOK_URL')]) {
+            sh '''
+                curl -X POST "${WEBHOOK_URL}" \
+                    -H 'Content-Type: application/json' \
+                    -d '{"text": "\\n\\n**Build Failed**\\n\\n**Issues detected:** ''' + patternMap.description + '''\\n\\n**Log context:**\\n```\\n''' + patternMap.context.replace("'", "\\'") + '''\\n```\\n\\n**Job:** ''' + env.JOB_NAME + '''\\n\\n**Build:** #''' + env.BUILD_NUMBER + '''\\n\\n**URL:** ''' + env.RUN_DISPLAY_URL + '''"}'
+            '''
+        }
+    }
+}

 class Version {
     int major, minor, patch
     @Override
@@ -1557,6 +1585,25 @@ pipeline {
                         cleanWs()
                     }
                 }
+                stage("Run CK_TILE_FMHA Tests on gfx1201")
+                {
+                    when {
+                        beforeAgent true
+                        expression { params.RUN_CK_TILE_FMHA_TESTS.toBoolean() }
+                    }
+                    agent{ label rocmnode("gfx1201") }
+                    environment{
+                        setup_args = "NO_CK_BUILD"
+                        execute_args = """ ../script/cmake-ck-dev.sh ../ gfx12-generic && \
+                            make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \
+                            cd ../ &&
+                            example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx1201 """
+                    }
+                    steps{
+                        buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
+                        cleanWs()
+                    }
+                }
             }
         }
         stage("Run TILE_ENGINE_GEMM Tests")
@@ -1863,7 +1910,7 @@ pipeline {
             }
             agent{ label 'miopen && (gfx1101 || gfx1100)' }
             environment{
-                setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """
+                setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
                 execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
                     cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
                         -DGPU_TARGETS="gfx11-generic" \

@@ -1884,7 +1931,7 @@ pipeline {
             }
             agent{ label rocmnode("gfx1201") }
             environment{
-                setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DUSE_OPT_GFX12=ON -DCMAKE_CXX_FLAGS=" -O3 " """
+                setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
                 execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
                     cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
                         -DGPU_TARGETS="gfx12-generic" \
@@ -1948,30 +1995,7 @@ pipeline {
         failure {
             node(rocmnode("nogpu")) {
                 script {
-                    // Get the build log.
-                    def buildLog = sh(script: 'wget -q --no-check-certificate -O - ' + BUILD_URL + 'consoleText', returnStdout: true)
-                    // Check for patterns in the log.
-                    def foundPatterns = []
-                    for (patternMap in failurePatterns) {
-                        def result = checkForPattern(patternMap.pattern, buildLog)
-                        if (result.found) {
-                            foundPatterns.add([
-                                description: patternMap.description,
-                                matchedLine: result.matchedLine,
-                                context: result.context
-                            ])
-                        }
-                    }
-                    // Send a notification for each matched failure pattern.
-                    for (patternMap in foundPatterns) {
-                        withCredentials([string(credentialsId: 'ck_ci_errors_webhook_url', variable: 'WEBHOOK_URL')]) {
-                            sh '''
-                                curl -X POST "${WEBHOOK_URL}" \
-                                    -H 'Content-Type: application/json' \
-                                    -d '{"text": "\\n\\n**Build Failed**\\n\\n**Issues detected:** ''' + patternMap.description + '''\\n\\n**Log context:**\\n```\\n''' + patternMap.context.replace("'", "\\'") + '''\\n```\\n\\n**Job:** ''' + env.JOB_NAME + '''\\n\\n**Build:** #''' + env.BUILD_NUMBER + '''\\n\\n**URL:** ''' + env.RUN_DISPLAY_URL + '''"}'
-                            '''
-                        }
-                    }
+                    sendFailureNotifications()
                 }
             }
         }
@@ -1,8 +1,8 @@
 set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
-# Currently only gfx9 archs are supported by FMHA
-list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
+# Currently only gfx9 and gfx12 archs are supported by FMHA
+list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12")
 if(NOT INST_TARGETS)
-    message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
+    message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
     return()
 endif()
@@ -12,6 +12,7 @@ set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
     "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
 if(BUILD_TESTING)
     # Build instances of all APIs for tests
     message(DEBUG "Enabling all FWD APIs of CK Tile FMHA because testing is enabled")
     set(FMHA_FWD_ENABLE_APIS "all")
 endif()
 if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
@@ -36,15 +37,19 @@ file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
 # re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change
 set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")

+list(JOIN INST_TARGETS "," FMHA_TARGETS_ARG)
+
 string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
 set(FMHA_FWD_CODE_GEN_COMMON_ARGS
     ${CMAKE_CURRENT_LIST_DIR}/generate.py
+    --targets ${FMHA_TARGETS_ARG}
     --api ${FMHA_FWD_APIS}
     --optdim 32,64,128,256
     # --filter fmha_fwd...
 )
 set(FMHA_BWD_CODE_GEN_COMMON_ARGS
     ${CMAKE_CURRENT_LIST_DIR}/generate.py
+    --targets ${FMHA_TARGETS_ARG}
     --api bwd
     --receipt 3
     --optdim 32,64,96,128,256
@@ -67,7 +72,7 @@ execute_process(
     RESULT_VARIABLE ret
 )
 if(ret AND NOT ret EQUAL 0)
-    message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
+    message(FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD kernels via Python.")
 endif()

 execute_process(
@@ -76,7 +81,7 @@ execute_process(
     RESULT_VARIABLE ret
 )
 if(ret AND NOT ret EQUAL 0)
-    message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.")
+    message(FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.")
 endif()
 # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
@@ -89,6 +94,7 @@ add_custom_command(
     COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
         --output_dir ${CMAKE_CURRENT_BINARY_DIR}
     DEPENDS ${CODE_GEN_SCRIPTS}
     COMMENT "Generate CK Tile FMHA FWD kernels"
 )

 add_custom_command(
@@ -96,6 +102,7 @@ add_custom_command(
     COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
         --output_dir ${CMAKE_CURRENT_BINARY_DIR}
     DEPENDS ${CODE_GEN_SCRIPTS}
     COMMENT "Generate CK Tile FMHA BWD kernels"
 )

 set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
42
example/ck_tile/01_fmha/codegen/arch.py
Normal file
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

from dataclasses import dataclass, field
from typing import Any, List, Callable


@dataclass(frozen=True)
class ArchTrait:
    name: str
    preprocessor_check: str = field(default=None)
    device_name_check: str = field(default=None)
    tag: str = field(default=None)
    filename_suffix: str = field(default=None)

    def __post_init__(self):
        if self.preprocessor_check is None:
            object.__setattr__(self, "preprocessor_check", f"defined(__{self.name}__)")
        if self.device_name_check is None:
            object.__setattr__(
                self,
                "device_name_check",
                f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0',
            )
        if self.tag is None:
            object.__setattr__(self, "tag", f"ck_tile::{self.name}_t")
        if self.filename_suffix is None:
            object.__setattr__(self, "filename_suffix", f"_{self.name}")


def get_factories_for_targets(
    targets: List[str], get_factory: Callable[[str], Any]
) -> List[Any]:
    factories = dict()
    for target in targets:
        factory = get_factory(target)
        factories[factory.arch.name] = factory
    # Place more specific architectures first
    factories = sorted(
        list(factories.values()), key=lambda f: len(f.arch.name), reverse=True
    )
    return factories
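A minimal stand-in reproducing `get_factories_for_targets`' ordering rule: deduplicate factories by architecture name, then sort longer (more specific) names first so gfx950 is checked before gfx9. `_Factory` below is a hypothetical placeholder for the real kernel factories.

```python
from dataclasses import dataclass

@dataclass
class _Factory:          # hypothetical stand-in for the real kernel factories
    arch_name: str

def order_factories(targets):
    factories = {t: _Factory(t) for t in targets}  # dedupe by name
    # Longer names are more specific (gfx950 vs gfx9), so they dispatch first.
    return sorted(factories.values(), key=lambda f: len(f.arch_name), reverse=True)

names = [f.arch_name for f in order_factories(["gfx9", "gfx950", "gfx12"])]
assert names == ["gfx950", "gfx12", "gfx9"]
```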
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: MIT
-# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 # generate kernel instances to speed up compilation

 import copy
@@ -21,6 +21,7 @@ from codegen.cpp_symbol_map import (
     BOOL_MAP,
     PIPELINE_ENUM_MAP,
 )
+from codegen.utils import update_file


 DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
@@ -441,7 +442,7 @@ class FmhaFwdApiPool:
             )
         if not per_dtypes:
             # empty string we add some ignore to suppress warning in api
-            per_dtypes += " (void)t ; (void)s ; (void)a;"
+            per_dtypes += " (void)t; (void)s; (void)a;"
         return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)
@@ -720,15 +721,20 @@ def get_fwd_blobs(


 def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
-    (autogen_dir / kernel.filename).write_text(kernel.template)
+    update_file(autogen_dir / kernel.filename, kernel.template)


 def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
-    (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
+    update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)


 def write_blobs(
-    output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
+    targets: List[str],
+    output_dir: Path,
+    kernel_filter: str,
+    receipt,
+    optdim_list,
+    mask_impl,
 ) -> None:
     api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
     for kernel in kernels:
@@ -737,7 +743,12 @@ def write_blobs(


 def list_blobs(
-    file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
+    targets: List[str],
+    file_path: Path,
+    kernel_filter: str,
+    receipt,
+    optdim_list,
+    mask_impl,
 ) -> None:
     with file_path.open("a") as f:
         _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
@@ -3,13 +3,14 @@
 # generate kernel instances to speed up compilation

 import copy
-from dataclasses import dataclass
 import fnmatch
 import itertools
 from collections import OrderedDict
+from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Tuple, Dict, Literal, Any
 from collections import defaultdict

+from codegen.arch import ArchTrait, get_factories_for_targets
 from codegen.cmake_config import GEN_DIR
 from codegen.cpp_symbol_map import (
     get_mask_check_map,
@@ -22,16 +23,20 @@ from codegen.cpp_symbol_map import (
     BWD_DTYPE_MAP,
     BOOL_MAP,
 )
-from codegen.utils import update_file
+from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file


 FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
 // auto generated by generate.py
 #include "fmha_bwd.hpp"
 """
 FMHA_BWD_DQ_DK_DV_KERNEL_BODY = """
 #include <iostream>

+#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
+
 using fmha_dtype_{F_idx} = {F_dtype};

 using fmha_block_tile_{F_idx} = ck_tile::
@@ -132,10 +137,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
     {F_maxq},
     {F_bn0}>;

-#include <iostream>
-
 template <>
-float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
+float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
 {{
     using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
     if(s.log_level_ > 0)
@@ -144,67 +147,68 @@ float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s
     const dim3 blocks = k_::BlockSize();
     constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
     return ck_tile::launch_kernel(
-        s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
+        s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
 }}

 template <>
-void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s,
-                                                        fmha_bwd_args a)
+void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
 {{
     using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
     auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
     const dim3 blocks = k_::BlockSize();
     constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
-    ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
+    ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs)(
         ck_tile::stream_config{{s.stream_id_}});
 }}

 template <>
-int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}>()
+int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
 {{
     using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
     return k_::kMaxSeqLenQ;
 }}

 template <>
-std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
+std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
 {{
     using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
     return k_::GetName();
 }}

+#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
 """
 FMHA_BWD_API_FILENAME = "fmha_bwd_api.cpp"
 FMHA_BWD_API = """
 #include <iostream>

-template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
+template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_, typename Arch>
 float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
 {{
     if constexpr (!std::is_same_v<convert_dq_trait_, void>)
     {{
         if(s.log_level_ > 0)
-            std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
+            std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_, Arch>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_, Arch>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_, Arch>() << std::flush;
         return ck_tile::launch_kernel(s,
-            [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
-            [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
-            [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
+            [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_, Arch>(s_, a); }},
+            [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_, Arch>(s_, a); }},
+            [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_, Arch>(s_, a); }}
         );
     }}
     else
     {{
         if(s.log_level_ > 0)
-            std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
+            std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_, Arch>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_, Arch>() << std::flush;
         return ck_tile::launch_kernel(s,
-            [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
-            [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
+            [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_, Arch>(s_, a); }},
+            [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_, Arch>(s_, a); }}
         );
     }}
 }}

 template <>
 float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
+    [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported();
+    [[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
     float r = -1;
     {F_dispatch}
     return r;
@@ -212,23 +216,22 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
 """


-def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_=0) -> str:
+def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, if_i=0) -> str:
     lines = [
-        f"{'if' if if_ == 0 else 'else if'}({F_cond})",
+        f"{if_(if_i)}({F_cond})",
         "{",
-        *[" " + line for line in F_body.split("\n") if line.strip() != ""],
+        indent(F_body),
         "}",
     ]
-    return "\n".join(" " * indent + line for line in lines) + "\n"
+    return "\n".join(lines) + "\n"


-FMHA_BWD_API_INNER_DISPATCH = """
-{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
-   ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{
+FMHA_BWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
+   ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_max_seq_q_cond}{F_cond_extra}) {{
     using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
     using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
     using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
-    r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
+    r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
     return r;
 }}
 """
@@ -283,6 +286,7 @@ class FmhaBwdDQDKDVTileSize:

 @dataclass(frozen=True)
 class FmhaBwdDQDKDVKernel:
+    F_arch: ArchTrait
     F_idx: int  # this is not a tunable, but a counter to differentiate symbol
     F_hdim: int  # hdim
     F_dtype: str  # data type
@@ -302,6 +306,7 @@ class FmhaBwdDQDKDVKernel:
     def template(self) -> str:
         return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
             F_idx=self.F_idx,
+            F_arch=self.F_arch,
             F_hdim=self.F_hdim,
             F_dtype=BWD_DTYPE_MAP[self.F_dtype],
             F_bm0=self.F_tile.F_bm0,
@@ -399,43 +404,97 @@ class FmhaBwdDQDKDVKernel:

     @property
     def filename(self) -> str:
-        return self.name + ".cpp"
+        return f"{self.name}{self.F_arch.filename_suffix}.cpp"
 # TODO: design a more practical way to do it
 # this is current supported tile size.
 def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
     if dtype == "fp32" and tr_load == "f":
         return [
             # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
             FmhaBwdDQDKDVTileSize( 32, 128,  32, 32,  32, 32, 64,  32,  32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1),
             FmhaBwdDQDKDVTileSize( 16,  64,  64, 16,  64, 16, 16,  64,  64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
             FmhaBwdDQDKDVTileSize( 16,  64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
         ]  # fmt: skip
     elif (dtype == "fp16" or dtype == "bf16") and tr_load == "f":
         return [
             FmhaBwdDQDKDVTileSize( 32, 128,  32, 32,  32, 32, 64,  32,  32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
             FmhaBwdDQDKDVTileSize( 32, 128,  64, 32,  64, 32, 32,  64,  64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
             FmhaBwdDQDKDVTileSize( 32, 128,  96, 32,  96, 32, 32,  96,  96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
             FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
             # FmhaBwdDQDKDVTileSize( 32,  64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
             FmhaBwdDQDKDVTileSize( 16,  64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
         ]  # fmt: skip
     elif (dtype == "fp16" or dtype == "bf16") and tr_load == "t":
         return [
             FmhaBwdDQDKDVTileSize( 32, 128,  64, 32,  64, 32, 32,  64,  64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
             FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
             FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
             # FmhaBwdDQDKDVTileSize( 32,  32,  64, 32,  64, 32, 32,  64,  64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32),
             FmhaBwdDQDKDVTileSize( 32,  16,  64, 32,  64, 32, 16,  64,  64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32),
|
||||
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
|
||||
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
|
||||
] # fmt: skip
|
||||
else:
|
||||
class KernelComponentFactoryBase:
|
||||
pass
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
|
||||
arch = ArchTrait(
|
||||
"gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
|
||||
if tr_load == "t":
|
||||
return []
|
||||
if dtype in ["fp32"]:
|
||||
return [
|
||||
# bm0, bn0, bk0, bk1, bk2, bk3, bk4,bhdq,bhdv,
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
|
||||
] # fmt: skip
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return [
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
] # fmt: skip
|
||||
return []
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
|
||||
arch = ArchTrait("gfx950")
|
||||
|
||||
@staticmethod
|
||||
def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
|
||||
results = KernelComponentFactoryGfx9.get_dq_dk_dv_tiles(dtype, tr_load)
|
||||
if dtype in ["fp16", "bf16"] and tr_load == "t":
|
||||
results.extend([
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32),
|
||||
FmhaBwdDQDKDVTileSize( 32, 16, 64, 32, 64, 32, 16, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32),
|
||||
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
|
||||
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
|
||||
]) # fmt: skip
|
||||
return results
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx12")
|
||||
|
||||
@staticmethod
|
||||
def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
|
||||
if tr_load == "t":
|
||||
return []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return [
|
||||
# bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
|
||||
FmhaBwdDQDKDVTileSize( 32, 64, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 64, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
] # fmt: skip
|
||||
return []
|
||||
|
||||
|
||||
def get_factory(target: str):
|
||||
# Place more specific architectures first
|
||||
|
||||
if target.startswith("gfx950"):
|
||||
return KernelComponentFactoryGfx950
|
||||
if target.startswith("gfx9"):
|
||||
return KernelComponentFactoryGfx9
|
||||
|
||||
if target.startswith("gfx12"):
|
||||
return KernelComponentFactoryGfx12
|
||||
|
||||
raise Exception(f"Unsupported device target {target}")
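As an aside on the `get_factory` dispatch above: the ordering of the `startswith` checks is load-bearing, because every `gfx950` target string also matches the generic `gfx9` prefix. A minimal standalone sketch of the same idea (the returned names here are placeholders, not the real factory classes):

```python
# Hypothetical stand-in for the real factory classes; only the
# prefix-dispatch logic from get_factory is illustrated.
def pick_factory(target: str) -> str:
    # More specific prefixes must be tested first: a target such as
    # "gfx950:xnack-" also starts with "gfx9", so the gfx950 check
    # has to come before the generic gfx9 one.
    if target.startswith("gfx950"):
        return "Gfx950"
    if target.startswith("gfx9"):
        return "Gfx9"
    if target.startswith("gfx12"):
        return "Gfx12"
    raise Exception(f"Unsupported device target {target}")

print(pick_factory("gfx950"))   # Gfx950
print(pick_factory("gfx942"))   # Gfx9
print(pick_factory("gfx1201"))  # Gfx12
```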


 FMHA_BWD_DOT_DO_O_KERNEL_BODY = """
+#include <iostream>
+
+#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
+
 using fmha_dtype_{F_idx} = {F_dtype};

 using fmha_bwd_dot_do_o_trait_{F_idx} =
@@ -445,7 +504,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot
     typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
     typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
     typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
-    /* BlockSize = M0 = */ 64,
+    /* BlockSize = M0 = */ {F_bm0},
     {F_hdim},
     {F_mode},
     fmha_bwd_dot_do_o_trait_{F_idx}>;
@@ -459,10 +518,8 @@ using fmha_bwd_dot_do_o_kernel_{F_idx} =
 using dot_do_o_trait_{F_idx} =
     fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;

-#include <iostream>
-
 template <>
-float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
+float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
 {{
     using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
     if(s.log_level_ > 0)
@@ -471,34 +528,38 @@ float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s
     const dim3 blocks = k_::BlockSize();
     constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
     return ck_tile::launch_kernel(
-        s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
+        s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
 }}

 template <>
-void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
+void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
 {{
     using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
     auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
     const dim3 blocks = k_::BlockSize();
     constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
-    ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
+    ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs)(
         ck_tile::stream_config{{s.stream_id_}});
 }}

 template <>
-std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
+std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}, {F_arch.tag}>()
 {{
     using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
     return k_::GetName();
 }}

+#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
 """

 @dataclass(frozen=True)
 class FmhaBwdOGradDotOKernel:
+    F_arch: ArchTrait
     F_idx: int  # this is not a tunable, but a counter to differentiate symbol
     F_hdim: int  # hdim
     F_dtype: str  # data type
+    F_bm0: int  # tile size along q seqlen (block size)
     F_spad: str  # true/false
     F_dvpad: str  #
     F_mode: str  # value from MODE_MAP
@@ -508,8 +569,10 @@ class FmhaBwdOGradDotOKernel:
     def template(self) -> str:
         return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
             F_idx=self.F_idx,
+            F_arch=self.F_arch,
             F_hdim=self.F_hdim,
             F_dtype=BWD_DTYPE_MAP[self.F_dtype],
+            F_bm0=self.F_bm0,
             F_spad=BOOL_MAP[self.F_spad],
             F_dvpad=BOOL_MAP[self.F_dvpad],
             F_mode=MODE_MAP[self.F_mode],
@@ -529,7 +592,7 @@ class FmhaBwdOGradDotOKernel:
             return n

         pn = pad_name()
-        n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}"
+        n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}_{self.F_mode}_o{self.F_occupancy}"
         if pn != "":
             n += f"_{pn}"
         else:
@@ -538,10 +601,14 @@ class FmhaBwdOGradDotOKernel:

     @property
     def filename(self) -> str:
-        return self.name + ".cpp"
+        return f"{self.name}{self.F_arch.filename_suffix}.cpp"


 FMHA_BWD_CONVERT_DQ_KERNEL_BODY = """
+#include <iostream>
+
+#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
+
 using fmha_dtype_{F_idx} = {F_dtype};

 using fmha_bwd_convert_dq_trait_{F_idx} =
@@ -573,10 +640,8 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
     {F_deterministic},
     {F_bn0}>;

-#include <iostream>
-
 template <>
-float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
+float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
 {{
     using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
     if(s.log_level_ > 0)
@@ -585,32 +650,34 @@ float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_confi
     const dim3 blocks = k_::BlockSize();
     constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
     return ck_tile::launch_kernel(
-        s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
+        s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
 }}

 template <>
-void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
-                                                            fmha_bwd_args a)
+void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
 {{
     using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
     auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
     const dim3 blocks = k_::BlockSize();
     constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
-    ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
+    ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs)(
         ck_tile::stream_config{{s.stream_id_}});
 }}

 template <>
-std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
+std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}, {F_arch.tag}>()
 {{
     using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
     return k_::GetName();
 }}

+#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
 """

 @dataclass(frozen=True)
 class FmhaBwdConvertQGradKernel:
+    F_arch: ArchTrait
     F_idx: int  # this is not a tunable, but a counter to differentiate symbol
     F_hdim: int  # hdim
     F_dtype: str  # data type
@@ -627,6 +694,7 @@ class FmhaBwdConvertQGradKernel:
     def template(self) -> str:
         return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
             F_idx=self.F_idx,
+            F_arch=self.F_arch,
             F_hdim=self.F_hdim,
             F_dtype=BWD_DTYPE_MAP[self.F_dtype],
             F_bm0=self.F_bm0,
@@ -664,11 +732,12 @@ class FmhaBwdConvertQGradKernel:

     @property
     def filename(self) -> str:
-        return self.name + ".cpp"
+        return f"{self.name}{self.F_arch.filename_suffix}.cpp"

 @dataclass(frozen=True)
 class FmhaBwdApiTrait:
+    arch: ArchTrait
     idx: int  # this is not a tunable, but a counter to differentiate symbol
     # sync with fmha_bwd_traits<>, to generate fallback calls
     hdim: int
@@ -705,10 +774,10 @@ class FmhaBwdApiTrait:
     @property
     def scheck(self) -> str:
         if self.mode == "group":
-            return "true"  # always support
+            return "true /*spad1d is always true in group mode*/"
         elif self.spad1d == "t":
             return f"a.seqlen_q % {M0_1D} != 0"
-        else:  # self.spad1d == 'f'
-            return f"true /*a.seqlen_q % {M0_1D} != 0*/"
+        else:  # self.spad1d == "f"
+            return f"a.seqlen_q % {M0_1D} == 0"

     @property
@@ -725,10 +794,17 @@ class FmhaBwdApiTrait:
         else:
             return f"a.hdim_v % {self.dvpad} == 0"

+    @property
+    def max_seq_q_cond(self) -> str:
+        if self.tile.max_seq_q != 0:
+            return f" && (a.seqlen_q <= {self.tile.max_seq_q})"
+        else:
+            return ""
+
     @property
     def extra_cond(self) -> str:
         if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
-            return "&& (a.seqlen_k <= 256)"
+            return " && (a.seqlen_k <= 256)"
         else:
             return ""

@@ -745,9 +821,11 @@ class FmhaBwdApiTrait:

         F_dvpad = "t" if self.dvpad else "f"
         return FmhaBwdOGradDotOKernel(
+            F_arch=self.arch,
             F_idx=self.idx,
             F_hdim=self.hdim,
             F_dtype=self.dtype,
+            F_bm0=M0_1D,
             F_spad=self.spad1d,
             F_dvpad=F_dvpad,
             F_mode=self.mode,
@@ -757,6 +835,7 @@ class FmhaBwdApiTrait:
     @property
     def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
         return FmhaBwdDQDKDVKernel(
+            F_arch=self.arch,
             F_idx=self.idx,
             F_hdim=self.hdim,
             F_dtype=self.dtype,
@@ -782,6 +861,7 @@ class FmhaBwdApiTrait:

         F_dpad = "t" if self.dpad else "f"
         return FmhaBwdConvertQGradKernel(
+            F_arch=self.arch,
             F_idx=self.idx,
             F_hdim=self.hdim,
             F_dtype=self.dtype,
@@ -798,28 +878,25 @@ class FmhaBwdApiTrait:

 class FmhaBwdApiPool:
     def __init__(self, mask_impl):
-        self.dq_dk_dv_pool = defaultdict(
-            lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
-        )
+        self.dq_dk_dv_pool = OrderedDict()
         self.mask_impl = mask_impl

     def register_dq_dk_dv_traits(self, trait: FmhaBwdApiTrait) -> None:
-        # TODO: do we need to check duplication?
-        self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][
-            trait.hdim
-        ].append(copy.copy(trait))
+        hdim = trait.hdim
+        ts = (
+            self.dq_dk_dv_pool.setdefault(trait.arch, OrderedDict())
+            .setdefault(trait.dtype, OrderedDict())
+            .setdefault(hdim, [])
+        )
+        check_duplicates_and_paddings(ts, trait)
+        ts.append(copy.copy(trait))
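The chained `setdefault` calls in `register_dq_dk_dv_traits` build the arch → dtype → hdim hierarchy lazily while preserving insertion order at every level. A self-contained sketch of the same pattern, with plain strings standing in for the trait objects:

```python
from collections import OrderedDict

pool = OrderedDict()

def register(arch, dtype, hdim, trait):
    # Each setdefault either creates the next level on first use or
    # returns the existing one, so the leaf list is shared across calls.
    ts = (
        pool.setdefault(arch, OrderedDict())
        .setdefault(dtype, OrderedDict())
        .setdefault(hdim, [])
    )
    ts.append(trait)

register("gfx9", "fp16", 128, "trait_a")
register("gfx9", "fp16", 128, "trait_b")
register("gfx12", "bf16", 64, "trait_c")

print(pool["gfx9"]["fp16"][128])  # ['trait_a', 'trait_b']
```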

-    @staticmethod
-    def if_(i: int) -> str:
-        return "if" if i == 0 else "else if"
-
-    def _api_innders(self, traits: List[FmhaBwdApiTrait]) -> str:
+    def _api_inners(self, traits: List[FmhaBwdApiTrait]) -> str:
         inners = ""
-        i = 0
-        for trait in traits:
+        for i_trait, trait in enumerate(traits):
             inners += FMHA_BWD_API_INNER_DISPATCH.format(
-                F_if=self.if_(i),
+                F_if=if_(i_trait),
+                F_arch=trait.arch,
                 F_mode=MODE_MAP[trait.mode],
                 F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                 F_mask=get_mask_map(self.mask_impl)[trait.mask],
@@ -840,27 +917,18 @@ class FmhaBwdApiPool:
                 F_trload=BOOL_MAP[trait.tr_load],
                 F_maxq=trait.tile.max_seq_q,
                 F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled],
-                F_bn0=trait.tile.F_bn0,
+                F_max_seq_q_cond=trait.max_seq_q_cond,
+                F_cond_extra=trait.extra_cond,
+                F_bn0=trait.tile.F_bn0,
+                F_convert_dq_bn0=trait.convert_dq_bn0,
             )
-            i += 1
         return inners

-    @staticmethod
-    def trload_sort_key(tf):
-        return 0 if tf == "t" else 1  # sort 't' before 'f'
-
-    @staticmethod
-    def max_seq_q_sort_key(max_seq_q):
-        return max_seq_q if max_seq_q != 0 else 1000000  # sort 0 to the end
-
     @staticmethod
-    def max_seq_q_cond(max_seq_q: int) -> str:
-        if max_seq_q == 0:
-            return "true /* no seqlen_q limit */"
-        else:
-            return f"a.seqlen_q <= {max_seq_q}"
+    def max_seq_q_sort_key(trait):
+        return (
+            trait.tile.max_seq_q if trait.tile.max_seq_q != 0 else 1000000
+        )  # sort 0 to the end
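The sort key above encodes the convention that `max_seq_q == 0` means "no seqlen_q limit": capped kernels must be tried before the unlimited fallback, so 0 is mapped to a large sentinel. The idea in isolation, using plain integers instead of trait objects:

```python
def max_seq_q_sort_key(max_seq_q: int) -> int:
    # 0 means "no limit"; map it past every real cap so that capped
    # variants sort first and the unlimited kernel acts as the fallback.
    return max_seq_q if max_seq_q != 0 else 1000000

print(sorted([0, 512, 32], key=max_seq_q_sort_key))  # [32, 512, 0]
```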

     @staticmethod
     def dtype_cond(dtype: str) -> str:
@@ -872,42 +940,34 @@ class FmhaBwdApiPool:

     @property
     def api(self) -> str:
-        tr_load_cond_map = {"t": "has_load_tr", "f": "true /* no trload requirement */"}
-        per_tr_load = ""
-        for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key):
-            per_max_seq_q = ""
-            for max_seq_q in sorted(
-                self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key
-            ):
-                per_dtypes = ""
-                for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]):
-                    per_hdim_case = ""
-                    for k, hdim in enumerate(
-                        self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]
-                    ):
-                        traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim]
-                        inners = self._api_innders(traits)
-                        per_hdim_case += FMHA_BWD_API_COND_STATEMENT(
-                            if_=k, F_cond=self.hdim_cond(hdim), F_body=inners
-                        )
-                    per_dtypes += FMHA_BWD_API_COND_STATEMENT(
-                        if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case
-                    )
-                per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(
-                    F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes
-                )
-            per_tr_load += FMHA_BWD_API_COND_STATEMENT(
-                F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4
-            )
+        per_arch = ""
+        for i_arch, (arch, pool_by_arch) in enumerate(self.dq_dk_dv_pool.items()):
+            per_dtypes = ""
+            for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
+                per_hdim_case = ""
+                for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
+                    traits = sorted(pool_by_hdim, key=self.max_seq_q_sort_key)
+                    inners = self._api_inners(traits)
+                    per_hdim_case += FMHA_BWD_API_COND_STATEMENT(
+                        if_i=i_hdim, F_cond=self.hdim_cond(hdim), F_body=inners
+                    )
+                per_dtypes += FMHA_BWD_API_COND_STATEMENT(
+                    if_i=i_dtype, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case
+                )
+            per_arch += FMHA_BWD_API_COND_STATEMENT(
+                if_i=i_arch, F_cond=arch.device_name_check, F_body=per_dtypes
+            )
-        if not per_tr_load:
+        if not per_arch:
             # empty string we add some ignore to suppress warning in api
-            per_tr_load += " (void)t ; (void)s ; (void)a; (void)has_load_tr;"
-        result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch=per_tr_load)
+            per_arch = "(void)t; (void)s; (void)a;"
+        result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(
+            F_dispatch=indent(per_arch)
+        )
         return result.replace("\n\n", "\n")


 def get_bwd_blobs(
-    filter_list: str, receipt, mask_impl, optdim_list
+    targets: List[str], filter_list: str, receipt, mask_impl, optdim_list
 ) -> Tuple[
     FmhaBwdApiPool,
     List[FmhaBwdOGradDotOKernel],
@@ -922,14 +982,19 @@ def get_bwd_blobs(
     filter_convert_dq = filters[1]
     filter_dq_dk_dv = filters[2]

+    factories = get_factories_for_targets(targets, get_factory)
+
     # use dict as ordered set
-    gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {}
-    gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {}
-    gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {}
+    gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = OrderedDict()
+    gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = OrderedDict()
+    gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = OrderedDict()
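The "use dict as ordered set" idiom works because dict keys are unique and (in CPython 3.7+, and in `OrderedDict` everywhere) iterate in first-insertion order, so duplicate kernel registrations collapse while the emitted file list stays deterministic. A tiny illustration with placeholder kernel names:

```python
# Keys act as set members; the stored value (True) is irrelevant.
gen: dict = {}
for kernel in ["kernel_a", "kernel_b", "kernel_a", "kernel_c"]:
    gen[kernel] = True  # re-registering a kernel does not change ordering

print(list(gen))  # ['kernel_a', 'kernel_b', 'kernel_c']
```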
     api_pool = FmhaBwdApiPool(mask_impl)

-    for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
-        tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load)
+    for factory, dtype, tr_load in itertools.product(
+        factories, BWD_DTYPE_MAP.keys(), ["t", "f"]
+    ):
+        tiles: Any = factory.get_dq_dk_dv_tiles(dtype, tr_load)
         spad1d_options = ["f", "t"]
         dpad_options = itertools.product(*([[0, 8, 1]] * 2))
         tf = ["t", "f"]
         for tile, mode, mask, bias, dbias, dropout, spad1d, (
@@ -942,7 +1007,7 @@ def get_bwd_blobs(
             BIAS_MAP.keys(),
             tf,
             DROPOUT_MAP.keys(),
-            tf,
+            spad1d_options,
             dpad_options,
             tf,
         ):
@@ -958,6 +1023,8 @@ def get_bwd_blobs(
                 continue
             if "wg32" in dropout:
                 continue
+            if spad1d == "f" and tile.max_seq_q != 0 and tile.max_seq_q < M0_1D:
+                continue  # max_seq_q < M0_1D requires padding
             if tr_load == "t":
                 # tr_load can only work with 8 pad
                 if dpad != dvpad or dpad == 1:
@@ -970,6 +1037,7 @@ def get_bwd_blobs(
             if hdim not in optdim_list:
                 continue
             t = FmhaBwdApiTrait(
+                arch=factory.arch,
                 idx=0,
                 hdim=hdim,
                 dtype=dtype,
@@ -989,10 +1057,10 @@ def get_bwd_blobs(

             if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
                 continue
-            if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv):
-                continue
             if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq):
                 continue
+            if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv):
+                continue

             # Flash attention integration
             if receipt == 2:
@@ -1076,10 +1144,15 @@ def get_bwd_blobs(

 def write_blobs(
-    output_dir: Path, filter_list: str, receipt, optdim_list, mask_impl
+    targets: List[str],
+    output_dir: Path,
+    filter_list: str,
+    receipt,
+    optdim_list,
+    mask_impl,
 ) -> None:
     api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(
-        filter_list, receipt, mask_impl, optdim_list
+        targets, filter_list, receipt, mask_impl, optdim_list
     )
     update_file(output_dir / FMHA_BWD_API_FILENAME, api_pool.api)
     for k in kernels_dot_do_o:
@@ -1091,10 +1164,15 @@ def write_blobs(

 def list_blobs(
-    file_path: Path, filter_list: str, receipt, optdim_list, mask_impl
+    targets: List[str],
+    file_path: Path,
+    filter_list: str,
+    receipt,
+    optdim_list,
+    mask_impl,
 ) -> None:
     _, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(
-        filter_list, receipt, mask_impl, optdim_list
+        targets, filter_list, receipt, mask_impl, optdim_list
     )
     with file_path.open("a") as f:
         for k in kernels_dot_do_o:

@@ -1,15 +1,17 @@
 # SPDX-License-Identifier: MIT
-# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 # generate kernel instances to speed up compilation

 import copy
-from dataclasses import dataclass, field
 import fnmatch
 import itertools
 import os
+from collections import OrderedDict
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import List, Optional, Tuple

+from codegen.arch import ArchTrait, get_factories_for_targets
 from codegen.cmake_config import GEN_DIR
 from codegen.cpp_symbol_map import (
     LAYOUT_MAP,
@@ -23,7 +25,7 @@ from codegen.cpp_symbol_map import (
     BIAS_MAP,
     get_mask_map,
 )
-from codegen.utils import update_file
+from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file


 DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
@@ -31,13 +33,17 @@ DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
 K0_MAX_SUBMAX_MAP = {32: 32, 48: 48, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}

 FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
 // auto generated by generate.py
+#include "ck_tile/ops/fmha/block/variants.hpp"
 #include "fmha_fwd.hpp"
 """

 FMHA_FWD_KERNEL_BODY = """
+#include <iostream>
+
+#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
+
 using fmha_dtype_{F_idx} = {F_dtype};

 using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
@@ -99,10 +105,8 @@ using fmha_kernel_{F_idx} =
 using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
     {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;

-#include <iostream>
-
 template<>
-float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
+float fmha_fwd_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_args a)
 {{
     using k_ = fmha_kernel_{F_idx};
     if(s.log_level_ > 0)
@@ -110,8 +114,10 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
     auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
     const dim3 blocks = k_::BlockSize();
     constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
-    return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
+    return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
 }}

+#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
 """

 FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp"

@@ -148,13 +154,13 @@ unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seq
 }}
 }} // namespace

-float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
+float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s) {{
     float r = -1;

     [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate

     unsigned num_cus;
-    if (!get_num_cus(num_cus)) {{
+    if(!get_num_cus(num_cus)) {{
         return r;
     }}

@@ -162,32 +168,33 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config&
         return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
     }};

     [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported();
+    [[maybe_unused]] const std::string device_name = ck_tile::get_device_name();

 {F_dispatch}
     return r;
 }}
 """

-FMHA_FWD_API_PER_TRLOAD = """    {F_if}({F_trload_cond}){{
+FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{
 {F_dtype_case}
-    }}
+}}
 """

-FMHA_FWD_API_PER_DTYPE = """        {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
+FMHA_FWD_API_PER_DTYPE = """{F_if}(t.data_type.compare(\"{F_dtype}\") == 0) {{
 {F_hdim_case}
 }}
 """
-FMHA_FWD_API_PER_HDIM_CASE = """            {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
-{F_inner_dispatch}
-}}
-"""
-
-FMHA_FWD_API_INNER_DISPATCH = """                {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
-                    ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
-    using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
-    return fmha_fwd_<trait_>(s, a);
-}}
-"""
+FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
+{F_inner_dispatch}
+}}
+"""
+
+FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
+    ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
+    using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
+    return fmha_fwd_<trait_, {F_arch.tag}>(s, a);
+}}
+"""

@@ -207,6 +214,7 @@ class CppConstraint:
|
||||
|
||||
@dataclass
class FmhaFwdApiTrait:
    arch: ArchTrait
    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
@@ -413,40 +421,35 @@ class FmhaFwdPipeline:

class FmhaFwdApiPool:
    def __init__(self, mask_impl):
        self.pool = dict()
        self.pool = OrderedDict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        hdim = trait.hdim, trait.bn1
        if hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][hdim] = list()

        self.pool[trait.dtype][hdim].append(copy.copy(trait))
        ts = (
            self.pool.setdefault(trait.arch, OrderedDict())
            .setdefault(trait.dtype, OrderedDict())
            .setdefault(hdim, [])
        )
        check_duplicates_and_paddings(ts, trait)
        ts.append(copy.copy(trait))
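The new `register_traits` replaces explicit "if key not in dict" checks with chained `setdefault` calls, which create each nesting level on first use. A minimal standalone sketch of that pattern (the `Pool`/`register` names and the membership-based duplicate check are illustrative stand-ins, not the source's `check_duplicates_and_paddings`):

```python
from collections import OrderedDict
import copy


class Pool:
    """Illustrative three-level pool keyed by (arch, dtype, hdim)."""

    def __init__(self):
        self.pool = OrderedDict()

    def register(self, arch, dtype, hdim, trait):
        # Chained setdefault walks/creates arch -> dtype -> hdim in one
        # expression, returning the leaf list for this key combination.
        ts = (
            self.pool.setdefault(arch, OrderedDict())
            .setdefault(dtype, OrderedDict())
            .setdefault(hdim, [])
        )
        if trait in ts:  # simplistic stand-in for the real duplicate check
            raise ValueError(f"duplicate trait {trait!r}")
        ts.append(copy.copy(trait))


pool = Pool()
pool.register("gfx9", "fp16", (128, 128), "qr_row")
pool.register("gfx12", "fp16", (128, 128), "qr_row")
print(len(pool.pool))  # one bucket per arch
```

Using `OrderedDict` rather than `dict` makes the iteration order that later drives code generation explicit, even though plain `dict` preserves insertion order in modern Python.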

    @property
    def api(self) -> str:
        tr_load_cond_map = {"t": "has_load_tr", "f": "true"}

        per_tr_load = str()
        for tr_load in ["t", "f"]:
            per_arch = str()
            for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()):
                per_dtypes = str()
                for i, dtype in enumerate(self.pool.keys()):
                for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
                    per_hdim_case = str()
                    for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
                        traits = [
                            t
                            for t in self.pool[dtype][(hdim, hdim_v)]
                            if tr_load == t.tr_load
                        ]
                        max_bm0 = max((t.bm0 for t in traits), default=0)
                    for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate(
                        pool_by_dtype.items()
                    ):
                        max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0)
                        inners = str()
                        for k, trait in enumerate(traits):
                            if_k = "if" if k == 0 else "else if"
                            inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
                                F_if=if_k,
                        for i_trait, trait in enumerate(pool_by_hdim):
                            inners += FMHA_FWD_API_INNER_DISPATCH.format(
                                F_if=if_(i_trait),
                                F_arch=arch,
                                F_mode=MODE_MAP[trait.mode],
                                F_vlayout=LAYOUT_MAP[trait.vlayout],
                                F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
@@ -479,23 +482,24 @@ class FmhaFwdApiPool:
                                F_hdim=hdim,
                                F_dtype=FWD_DTYPE_MAP[dtype],
                            )
                        if_j = "if" if j == 0 else "else if"
                        per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                            F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
                        per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
                            F_if=if_(i_hdim),
                            F_hdim=hdim,
                            F_hdim_v=hdim_v,
                            F_inner_dispatch=indent(inners),
                        )
                    if_i = "if" if i == 0 else "else if"
                    per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                        F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
                    per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
                        F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
                    )
            per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
                F_if="if",
                F_trload_cond=tr_load_cond_map[tr_load],
                F_dtype_case=per_dtypes,
                per_arch += FMHA_FWD_API_PER_ARCH.format(
                    F_if=if_(i_arch),
                    F_arch=arch,
                    F_dtype_case=indent(per_dtypes),
                )
        if not per_tr_load:
        if not per_arch:
            # dispatch is empty: add (void) casts to suppress unused-parameter warnings in the api
            per_tr_load += " (void)t ; (void)s ; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
            per_arch = "(void)t; (void)s; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=indent(per_arch))
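The rewritten `api` generator leans on two small helpers imported from `codegen.utils`, `if_` and `indent`. Their exact implementations are not shown in this diff; a plausible minimal sketch (an assumption, not the repository's code) that reproduces the chained `if` / `else if` dispatch they build:

```python
def if_(i: int) -> str:
    # The first branch opens the chain; every later branch continues it.
    return "if" if i == 0 else "else if"


def indent(text: str, prefix: str = "    ") -> str:
    # Indent every non-empty line one level so nested dispatch blocks
    # line up inside their enclosing template.
    return "\n".join(prefix + line if line else line for line in text.splitlines())


branches = ""
for i, dtype in enumerate(["fp16", "bf16"]):
    branches += f'{if_(i)} (t.data_type == "{dtype}") {{ ... }}\n'
print(branches)
```

This replaces the repeated `if_k = "if" if k == 0 else "else if"` idiom at every nesting level with a single helper call, and moves indentation out of the string templates and into the generator.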


@dataclass
@@ -533,6 +537,7 @@ class FmhaFwdTileSize:

@dataclass
class FmhaFwdKernel:
    F_arch: ArchTrait
    F_idx: int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim: int  # hdim
    F_dtype: str  # data type
@@ -545,6 +550,7 @@ class FmhaFwdKernel:
    def template(self) -> str:
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
            F_idx=self.F_idx,
            F_arch=self.F_arch,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=self.F_tile.F_bm0,
@@ -596,10 +602,11 @@ class FmhaFwdKernel:

    @property
    def filename(self) -> str:
        return self.name + ".cpp"
        return f"{self.name}{self.F_arch.filename_suffix}.cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        return FmhaFwdApiTrait(
            arch=self.F_arch,
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
@@ -627,12 +634,16 @@ class FmhaFwdKernel:
        )


class KernelComponentFactory:
class KernelComponentFactoryGfx9:
    arch = ArchTrait(
        "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)"
    )

    # TODO: design a more practical way to do it
    # these are the currently supported tile sizes per hdim
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        if dtype == "fp32":
        if dtype in ["fp32"]:
            return {
                # bm0, bn0, bk0, bn1, bk1,
                ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
@@ -645,10 +656,10 @@ class KernelComponentFactory:
                (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
                (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
            } # fmt: skip
        elif dtype == "fp16" or dtype == "bf16":
        elif dtype in ["fp16", "bf16"]:
            return {
                ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
                ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
                             FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                             FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
@@ -656,18 +667,18 @@ class KernelComponentFactory:
                             FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                             FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                             FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
                # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
                (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                (192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
                (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
            } # fmt: skip
        elif dtype == "fp8" or dtype == "fp8bf16":
        elif dtype in ["fp8", "fp8bf16"]:
            return {
                ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
                (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
                (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
            } # fmt: skip
        elif dtype == "fp8fp32":
        elif dtype in ["fp8fp32"]:
            return {
                (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
            } # fmt: skip
@@ -680,7 +691,7 @@ class KernelComponentFactory:
    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        # this function will populate a list of possible pipelines
        # TODO: the order of the list matters! entries later in this list are also checked later
        # TODO: currently for qr pipeline, let 't' padding to appear later!!
        # TODO: currently for qr pipeline, let "t" padding appear later!!
        # TODO: how to design this more generically?
        pipelines = []
        if dtype in ["fp32"]:
@@ -719,18 +730,8 @@ class KernelComponentFactory:
                else:
                    pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
                    pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
                if (
                    (hdim, hdim_v) in [(64, 64), (128, 128)]
                    and logits == "f"
                    and bias == "no"
                    and dropout == "f"
                    and skip == "f"
                ):
                    pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
                    pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip

                if receipt == 1 and bias != "bias":
                    pipelines.append(FmhaFwdPipeline( "qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip  # TODO: cover arbitrary hdim
                    pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip  # TODO: cover arbitrary hdim
        elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
            # no need lse/dropout kernels
            for logits, squant, mask, bias in itertools.product(
@@ -746,29 +747,128 @@ class KernelComponentFactory:
        return pipelines


class CustomFactory(KernelComponentFactory):
class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
    arch = ArchTrait("gfx950")

    @staticmethod
    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        pipelines = KernelComponentFactoryGfx9.get_pipelines(
            dtype, hdim, hdim_v, receipt, mask_impl
        )
        if dtype in ["fp16", "bf16"]:
            squant = "f"
            for logits, mask, bias, lse, dropout, skip in itertools.product(
                ["t", "f"],
                get_mask_map(mask_impl).keys(),
                BIAS_MAP.keys(),
                ["t", "f"],
                ["t", "f"],
                ["t", "f"],
            ):
                if (
                    (hdim, hdim_v) in [(64, 64), (128, 128)]
                    and logits == "f"
                    and bias == "no"
                    and dropout == "f"
                    and skip == "f"
                ):
                    pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
                    pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
        return pipelines


class KernelComponentFactoryGfx12:
    arch = ArchTrait("gfx12")

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
        if dtype in ["fp16", "bf16"]:
            return {
                # bm0, bn0, bk0, bn1, bk1,
                ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
                ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
                (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
                (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
                (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
            } # fmt: skip
        elif dtype in ["fp8", "fp8bf16"]:
            return {
                # bm0, bn0, bk0, bn1, bk1,
                ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
                (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
                (256, 256) : [FmhaFwdTileSize( 64, 32, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
            } # fmt: skip
        elif dtype in ["fp8fp32"]:
            return {
                # bm0, bn0, bk0, bn1, bk1,
                (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
            } # fmt: skip
        else:
            return None

    @staticmethod
    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        pipelines = []
        if dtype in ["fp16", "bf16"]:
            squant = "f"
            for logits, mask, bias, lse, dropout, skip in itertools.product(
                ["t", "f"],
                get_mask_map(mask_impl).keys(),
                BIAS_MAP.keys(),
                ["t", "f"],
                ["t", "f"],
                ["t", "f"],
            ):
                pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
                pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
        elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
            # no need lse/dropout kernels
            for logits, squant, mask, bias in itertools.product(
                ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
            ):
                pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
                pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
        else:
            assert False
        return pipelines


class CustomFactory(KernelComponentFactoryGfx9):
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
        if dtype == "fp16" or dtype == "bf16":
            if (128, 128) in result.keys():
                result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
        return result


def get_factory(target: str):
    if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1":
        return CustomFactory

    # Place more specific architectures first

    if target.startswith("gfx950"):
        return KernelComponentFactoryGfx950
    if target.startswith("gfx9"):
        return KernelComponentFactoryGfx9

    if target.startswith("gfx12"):
        return KernelComponentFactoryGfx12

    raise Exception(f"Unsupported device target {target}")
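The "more specific architectures first" comment matters because targets are matched by string prefix: "gfx950" also starts with "gfx9", so testing "gfx9" first would shadow the gfx950 factory. A small self-contained sketch of the same resolution logic (the `Factory*` classes here are hypothetical stand-ins for the generator's factories):

```python
# Hypothetical stand-in factories; the real ones live in the codegen scripts.
class FactoryGfx950: pass
class FactoryGfx9: pass
class FactoryGfx12: pass


def get_factory(target: str):
    # Order matters: "gfx950" also matches the "gfx9" prefix, so the more
    # specific architecture must be tested first.
    for prefix, factory in (
        ("gfx950", FactoryGfx950),
        ("gfx9", FactoryGfx9),
        ("gfx12", FactoryGfx12),
    ):
        if target.startswith(prefix):
            return factory
    raise Exception(f"Unsupported device target {target}")
```

With this ordering, a target like "gfx942" resolves to the generic gfx9 factory while "gfx950" gets its own, and anything else fails loudly rather than silently generating for the wrong architecture.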


def get_fwd_blobs(
    kernel_filter: Optional[str], receipt, optdim_list, mask_impl
    targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)

    factory = (
        CustomFactory
        if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
        else KernelComponentFactory
    )
    factories = get_factories_for_targets(targets, get_factory)

    for dtype in FWD_DTYPE_MAP.keys():
    for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
@@ -791,7 +891,8 @@ def get_fwd_blobs(
                    # NOTE: this is used to speed up the deepseek prefill case; we don't generate training kernels
                    if pipeline.F_bias != "no" or pipeline.F_dropout == "t":
                        continue
                if dtype != "fp32":
                if factory.arch.name.startswith("gfx9") and dtype != "fp32":
                    # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support
                    if pipeline.tag != "qr_async_trload" and (
                        ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128)
                        or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)
@@ -811,6 +912,7 @@ def get_fwd_blobs(
                    ):
                        continue
                k = FmhaFwdKernel(
                    F_arch=factory.arch,
                    F_idx=0,
                    F_hdim=hdim,
                    F_dtype=dtype,
@@ -918,19 +1020,33 @@ def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:


def write_blobs(
    output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
    targets: List[str],
    output_dir: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    api_pool, kernels = get_fwd_blobs(
        targets, kernel_filter, receipt, optdim_list, mask_impl
    )
    for kernel in kernels:
        write_single_fwd_kernel(kernel, output_dir)
    write_fwd_api(api_pool, output_dir)


def list_blobs(
    file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
    targets: List[str],
    file_path: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    with file_path.open("a") as f:
        _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
        _, kernels = get_fwd_blobs(
            targets, kernel_filter, receipt, optdim_list, mask_impl
        )
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
        f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")

@@ -1,13 +1,16 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import copy
from dataclasses import dataclass
import fnmatch
import itertools
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

from codegen.arch import ArchTrait, get_factories_for_targets
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
    FWD_DTYPE_MAP,
@@ -16,16 +19,21 @@ from codegen.cpp_symbol_map import (
    LAYOUT_MAP,
    ROPE_CHECK_MAP,
)
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file

from codegen.ops.fmha_fwd import (
    FmhaFwdApiTrait,
    FMHA_FWD_KERNEL_HEADER,
    FMHA_FWD_API_PER_ARCH,
    FMHA_FWD_API_PER_DTYPE,
    FMHA_FWD_API_PER_HDIM_CASE,
)

FMHA_FWD_APPENDKV_KERNEL_BODY = """
#include <iostream>

#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})

using fmha_dtype_{F_idx} = {F_dtype};

using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
@@ -55,10 +63,8 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
    {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;

#include <iostream>

template<>
float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
float fmha_fwd_appendkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
{{
    using k_ = fmha_kernel_{F_idx};
    if(s.log_level_ > 0)
@@ -66,31 +72,37 @@ float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fw
    auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
    const dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
}}

#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""

FMHA_FWD_APPENDKV_API_FILENAME = "fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API = """
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s) {{
    float r = -1;

    [[maybe_unused]] const std::string device_name = ck_tile::get_device_name();

{F_dispatch}
    return r;
}}
"""

FMHA_FWD_APPENDKV_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
    ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
    ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
    using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
    return fmha_fwd_appendkv_<trait_>(s, a);
}}
FMHA_FWD_APPENDKV_API_INNER_DISPATCH = """{F_if}((t.is_v_rowmajor == {F_vlayout}) &&
    ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
    ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
    using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
    return fmha_fwd_appendkv_<trait_, {F_arch.tag}>(s, a);
}}
"""
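These templates rely on two `str.format` conventions: doubled braces `{{ }}` emit literal C++ braces, while `{F_...}` fields are substituted (the real templates also use attribute access such as `{F_arch.tag}`; the sketch below passes a plain string for simplicity). A minimal illustration, not taken from the source:

```python
# Doubled braces survive formatting as literal braces; {F_...} fields are
# replaced. The template below is a cut-down stand-in for the real one.
TEMPLATE = """{F_if}((t.is_v_rowmajor == {F_vlayout})) {{
    return fmha_fwd_appendkv_<trait_, {F_arch}>(s, a);
}}
"""

snippet = TEMPLATE.format(F_if="if", F_vlayout="true", F_arch="ck_tile::gfx12")
print(snippet)
```

This is why the generator can build valid C++ with nested braces: every structural brace in the template is doubled, and only the named `F_*` placeholders are touched by `.format`.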


@dataclass
class FmhaFwdAppendKVApiTrait:
    # sync with fmha_fwd_traits<>, to generate fallback calls
    arch: ArchTrait
    # sync with fmha_fwd_appendkv_traits, to generate fallback calls
    hdim: str
    dtype: str  # data type
    bs: int  # tile size along q seqlen
@@ -178,62 +190,70 @@ class FmhaFwdAppendKVPipeline:

class FmhaFwdAppendKVApiPool:
    def __init__(self, mask_impl):
        self.pool = dict()
        self.pool = OrderedDict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        if trait.hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][trait.hdim] = list()

        self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
    def register_traits(self, trait: FmhaFwdAppendKVApiTrait) -> None:
        hdim = trait.hdim
        ts = (
            self.pool.setdefault(trait.arch, OrderedDict())
            .setdefault(trait.dtype, OrderedDict())
            .setdefault(hdim, [])
        )
        check_duplicates_and_paddings(ts, trait)
        ts.append(copy.copy(trait))

    @property
    def api(self) -> str:
        per_dtypes = str()
        for i, dtype in enumerate(self.pool.keys()):
            per_hdim_case = str()
            for j, hdim in enumerate(self.pool[dtype].keys()):
                traits = self.pool[dtype][hdim]
                inners = str()
                for k, trait in enumerate(traits):
                    if_k = "if" if k == 0 else "else if"
                    inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(
                        F_if=if_k,
                        F_vlayout=LAYOUT_MAP[trait.vlayout],
                        F_scheck=trait.scheck,
                        F_skcheck=trait.skcheck,
                        F_dcheck=trait.dcheck,
                        F_dvcheck=trait.dvcheck,
                        F_rope_check=ROPE_CHECK_MAP[trait.rope],
                        F_pagedkv=BOOL_MAP[trait.pagedkv],
                        F_spad=BOOL_MAP[trait.spad],
                        F_skpad=BOOL_MAP[trait.skpad],
                        F_dpad=BOOL_MAP[trait.dpad],
                        F_dvpad=BOOL_MAP[trait.dvpad],
                        F_rope=ROPE_MAP[trait.rope],
                        F_bs=trait.bs,
                        F_bsk=trait.bsk,
                        F_bd=trait.bd,
                        F_bdv=trait.bdv,
        per_arch = str()
        for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()):
            per_dtypes = str()
            for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
                per_hdim_case = str()
                for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
                    inners = str()
                    for i_trait, trait in enumerate(pool_by_hdim):
                        inners += FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(
                            F_if=if_(i_trait),
                            F_arch=arch,
                            F_vlayout=LAYOUT_MAP[trait.vlayout],
                            F_scheck=trait.scheck,
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_rope_check=ROPE_CHECK_MAP[trait.rope],
                            F_pagedkv=BOOL_MAP[trait.pagedkv],
                            F_spad=BOOL_MAP[trait.spad],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad],
                            F_rope=ROPE_MAP[trait.rope],
                            F_bs=trait.bs,
                            F_bsk=trait.bsk,
                            F_bd=trait.bd,
                            F_bdv=trait.bdv,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
                        )
                    per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
                        F_if=if_(i_hdim),
                        F_hdim=hdim,
                        F_dtype=FWD_DTYPE_MAP[dtype],
                        F_hdim_v=hdim,
                        F_inner_dispatch=indent(inners),
                    )
                if_j = "if" if j == 0 else "else if"
                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                    F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners
                per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
                    F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
                )
            if_i = "if" if i == 0 else "else if"
            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
            per_arch += FMHA_FWD_API_PER_ARCH.format(
                F_if=if_(i_arch),
                F_arch=arch,
                F_dtype_case=indent(per_dtypes),
            )
        if not per_dtypes:
        if not per_arch:
            # dispatch is empty: add (void) casts to suppress unused-parameter warnings in the api
            per_dtypes += " (void)t ; (void)s ; (void)a;"
            per_arch = "(void)t; (void)s; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(
            F_dispatch=per_dtypes
            F_dispatch=indent(per_arch)
        )

@@ -254,6 +274,7 @@ class FmhaFwdAppendKVTileSize:

@dataclass
class FmhaFwdAppendKVKernel:
    F_arch: ArchTrait
    F_idx: int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim: int  # hdim
    F_dtype: str  # data type
@@ -265,6 +286,7 @@ class FmhaFwdAppendKVKernel:
    def template(self) -> str:
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_KERNEL_BODY.format(
            F_idx=self.F_idx,
            F_arch=self.F_arch,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bs=self.F_tile.F_bs,
@@ -293,10 +315,11 @@ class FmhaFwdAppendKVKernel:

    @property
    def filename(self) -> str:
        return self.name + ".cpp"
        return f"{self.name}{self.F_arch.filename_suffix}.cpp"

    def api_trait(self) -> FmhaFwdAppendKVApiTrait:
        return FmhaFwdAppendKVApiTrait(
            arch=self.F_arch,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            bs=self.F_tile.F_bs,
@@ -313,31 +336,26 @@ class FmhaFwdAppendKVKernel:
        )


# TODO: design a more practical way to do it
# these are the currently supported tile sizes per hdim
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
    if dtype == "fp16" or dtype == "bf16":
        return {
            "32": FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
            "64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
            "128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
            "256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
        }
    elif dtype == "fp8" or dtype == "bf8":
        return {
            "64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
            "128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
            "256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
        }
    else:
        return None
class KernelComponentFactoryBase:
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        if dtype in ["fp16", "bf16"]:
            return {
                "32": FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
                "64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
                "128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
                "256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
            }
        elif dtype in ["fp8", "bf8"]:
            return {
                "64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
                "128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
                "256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
            }
        else:
            return None
|
||||
|
||||
def get_fwd_appendkv_blobs(
|
    kernel_filter: Optional[str], receipt, mask_impl, optdim_list
) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
    # TODO: we don't support tuning yet, so pick one value for vlayout/pipeline/pad;
    # support this in the future
    @staticmethod
    def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
        # this function populates a list of possible pipelines
        # TODO: the order of the list matters! entries later in the list are also checked later
@@ -347,19 +365,18 @@ def get_fwd_appendkv_blobs(
        if dtype in ["fp16", "bf16"]:
            # NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
            # applying rotary embedding, so I just use 't' in inter/half pipelines
-            for vlayout in ["row", "col"]:
-                for pagedkv in ["t", "f"]:
-                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "f", "f", "no", pagedkv)) # fmt: skip
-                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "no", pagedkv)) # fmt: skip
+            for vlayout, pagedkv in itertools.product(["row"], ["t", "f"]):
+                pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "f", "f", "no", pagedkv)) # fmt: skip
+                pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "no", pagedkv)) # fmt: skip

-                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "inter", pagedkv)) # fmt: skip
-                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "inter", pagedkv)) # fmt: skip
+                pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "inter", pagedkv)) # fmt: skip
+                pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "inter", pagedkv)) # fmt: skip

-                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "half", pagedkv)) # fmt: skip
-                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "half", pagedkv)) # fmt: skip
+                pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "half", pagedkv)) # fmt: skip
+                pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "half", pagedkv)) # fmt: skip
        elif dtype in ["fp8", "bf8"]:
            # rope/paged-kv is not supported
-            pipelines.append(FmhaFwdAppendKVPipeline("col", "t", "t", "t", "t", "no", "f")) # fmt: skip
+            pipelines.append(FmhaFwdAppendKVPipeline("row", "t", "t", "t", "t", "no", "f")) # fmt: skip
        elif dtype in ["fp8fp16", "fp8bf16"]:
            # TODO
            None
@@ -367,18 +384,45 @@ def get_fwd_appendkv_blobs(
        assert False
        return pipelines

+class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
+    arch = ArchTrait("gfx9")
+
+
+class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
+    arch = ArchTrait("gfx12")
+
+
+def get_factory(target: str):
+    # Place more specific architectures first
+
+    if target.startswith("gfx9"):
+        return KernelComponentFactoryGfx9
+
+    if target.startswith("gfx12"):
+        return KernelComponentFactoryGfx12
+
+    raise Exception(f"Unsupported device target {target}")
+
+
def get_fwd_appendkv_blobs(
+    targets: List[str], kernel_filter: Optional[str], receipt, mask_impl, optdim_list
) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
    gen = list()
    api_pool = FmhaFwdAppendKVApiPool(mask_impl)

-    for dtype in FWD_DTYPE_MAP.keys():
-        d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
+    factories = get_factories_for_targets(targets, get_factory)
+
+    for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
+        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        for hdim_str in d.keys():
            tile = d[hdim_str]
            hdim = int(hdim_str)
-            for pipeline in get_pipelines(dtype, hdim):
+            for pipeline in factory.get_pipelines(dtype, hdim):
                k = FmhaFwdAppendKVKernel(
+                    F_arch=factory.arch,
                    F_idx=0,
                    F_hdim=hdim,
                    F_dtype=dtype,
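The per-target dispatch above pairs `get_factory` with `get_factories_for_targets` (imported from `codegen.arch`). A minimal sketch of how such a helper could map build targets to a deduplicated, order-preserving list of factories — only the names come from the diff, the helper's body here is an assumption:

```python
class KernelComponentFactoryGfx9:
    """Stand-in for the gfx9 factory in the diff."""

class KernelComponentFactoryGfx12:
    """Stand-in for the gfx12 factory in the diff."""

def get_factory(target: str):
    # Place more specific architectures first
    if target.startswith("gfx9"):
        return KernelComponentFactoryGfx9
    if target.startswith("gfx12"):
        return KernelComponentFactoryGfx12
    raise Exception(f"Unsupported device target {target}")

def get_factories_for_targets(targets, factory_fn):
    # Map every target to a factory, keeping first-seen order and dropping
    # duplicates so each architecture family generates its kernels once.
    factories = []
    for target in targets:
        factory = factory_fn(target)
        if factory not in factories:
            factories.append(factory)
    return factories
```

For example, `get_factories_for_targets(["gfx90a", "gfx942", "gfx1201"], get_factory)` would yield the gfx9 factory once, followed by the gfx12 factory.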
@@ -418,18 +462,23 @@ def get_fwd_appendkv_blobs(


def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None:
-    (autogen_dir / kernel.filename).write_text(kernel.template)
+    update_file(autogen_dir / kernel.filename, kernel.template)


def write_fwd_appendkv_api(api_pool: FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
-    (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)
+    update_file(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME, api_pool.api)


def write_blobs(
-    output_dir: Path, kernel_filter: Optional[str], receipt, optdim_list, mask_impl
+    targets: List[str],
+    output_dir: Path,
+    kernel_filter: Optional[str],
+    receipt,
+    optdim_list,
+    mask_impl,
) -> None:
    api_pool, kernels = get_fwd_appendkv_blobs(
-        kernel_filter, receipt, mask_impl, optdim_list
+        targets, kernel_filter, receipt, mask_impl, optdim_list
    )
    for kernel in kernels:
        write_single_kernel(kernel, output_dir)
@@ -437,11 +486,16 @@ def write_blobs(


def list_blobs(
-    file_path: Path, kernel_filter: Optional[str], receipt, optdim_list, mask_impl
+    targets: List[str],
+    file_path: Path,
+    kernel_filter: Optional[str],
+    receipt,
+    optdim_list,
+    mask_impl,
) -> None:
    with file_path.open("a") as f:
        _, kernels = get_fwd_appendkv_blobs(
-            kernel_filter, receipt, mask_impl, optdim_list
+            targets, kernel_filter, receipt, mask_impl, optdim_list
        )
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")

@@ -3,12 +3,14 @@
# generate kernel instances to speed up compilation

import copy
-from dataclasses import dataclass
import fnmatch
import itertools
+from collections import OrderedDict
+from dataclasses import dataclass
from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union

+from codegen.arch import ArchTrait, get_factories_for_targets
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
    PIPELINE_ENUM_MAP,
@@ -21,32 +23,29 @@ from codegen.cpp_symbol_map import (
    get_mask_map,
    BOOL_MAP,
)
+from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
+
+from codegen.ops.fmha_fwd import (
+    FmhaFwdTileSize,
+    DTYPE_BITS,
+    K0_MAX_SUBMAX_MAP,
+    FMHA_FWD_KERNEL_HEADER,
+    FMHA_FWD_API_PER_ARCH,
+    FMHA_FWD_API_PER_DTYPE,
+    FMHA_FWD_API_PER_HDIM_CASE,
+)

-DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
-
-K0_MAX_SUBMAX_MAP = {
-    32: 32,
-    64: 64,
-    96: 128,
-    128: 128,
-    # 160: 160,
-    256: 256,
-}

FMHA_FWD_SPLITKV_PIPELINE_MAP = {
    "qr": "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
    "qr_nwarp_sshuffle": "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS",
}

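A note on the kernel-body strings defined below: they are plain Python `str.format` templates, so every literal C++ brace is doubled as `{{` / `}}` while `{F_...}` fields are substituted by the generator. A small self-contained illustration (the template text here is invented, not from the diff):

```python
# Literal C++ braces are escaped as {{ and }}; {F_...} fields are filled in
# by str.format, mirroring how the kernel bodies below are rendered.
TEMPLATE = """\
template<>
void run_<{F_idx}>() {{
    dispatch<{F_arch_tag}>();
}}
"""

rendered = TEMPLATE.format(F_idx=0, F_arch_tag="ck_tile::Gfx9")
```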
FMHA_FWD_SPLITKV_KERNEL_BODY = """
+#include <iostream>
+
+#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
+
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
@@ -113,17 +112,15 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
    auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
    const dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
-    ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
-  }}
-}};
+    ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
+  }}
+}}; // struct instance
}} // anonymous namespace

using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
    {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
    {F_dvpad}>;

-#include <iostream>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wtautological-compare"

@@ -147,7 +144,7 @@ void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
#pragma clang diagnostic pop

template<>
-void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
+void fmha_fwd_splitkv_oneshot_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
    if constexpr({F_mode} == false) {{ // batch mode
        // we don't check every seqlen_k value for kvcache
@@ -165,14 +162,20 @@ void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, f
}}

template<>
-std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
+std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}, {F_arch.tag}>()
{{
    using k_ = instance<true>::fmha_kernel; /// FIXME: choose real kernel type
    return k_::GetName();
}}

+#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""

FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY = """
+#include <iostream>
+
+#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
+
using fmha_dtype_{F_idx} = {F_dtype};

namespace {{
@@ -213,18 +216,16 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
    auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
    const dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
-    ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
-  }}
-}};
+    ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
+  }}
+}}; // struct instance
}} // anonymous namespace

using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1},
    {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;

-#include <iostream>

template<>
-void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
+void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
    if (a.num_splits <= 8) {{
        instance<3>::run(s, a);
@@ -240,73 +241,79 @@ void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_conf
}}

template<>
-std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>()
+std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}, {F_arch.tag}>()
{{
    using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type
    return k_::GetName();
}}

+#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""

FMHA_FWD_SPLITKV_API_FILENAME = "fmha_fwd_splitkv_api.cpp"
FMHA_FWD_SPLITKV_API = """
#include <iostream>

-template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
+template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_, typename Arch>
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
    if(s.log_level_ > 0)
-        std::cout
-            << ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
-            << ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
-            << std::flush;
+        std::cout
+            << ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_, Arch>()
+            << ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_, Arch>()
+            << std::flush;

    return ck_tile::launch_kernel(s,
-        [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
-        [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
+        [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_, Arch>(s_, a); }},
+        [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_, Arch>(s_, a); }}
    );
}}

-float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
+float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s) {{
    float r = -1;

+    [[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
+
{F_dispatch}
    return r;
}}
"""

-FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """    {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
-            ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
+        ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
+    using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

-        // get combine kernel tile sizes
-        using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
-        constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
+    // get combine kernel tile sizes
+    using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
+    constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, {F_bn1comb}>::kM0;

-        // make sure we can reuse the padding flags in combine kernels
-        static_assert({F_bm0} % kM0 == 0);
-        static_assert({F_bn1} % 32 == 0);
+    // make sure we can reuse the padding flags in combine kernels
+    static_assert({F_bm0} % kM0 == 0);
+    static_assert({F_bn1} % {F_bn1comb} == 0);

-        if (t.has_lse) {{
-            if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
-                return -1;
-            }} else {{
-                using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
+    if (t.has_lse) {{
+        if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
+            return -1;
+        }} else {{
+            using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1comb}, true, {F_squant}, {F_spad}, {F_dvpad}>;

-                return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
-            }}
-        }} else {{
-            using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>;
+            return fmha_fwd_splitkv_<traits_, traits2_, {F_arch.tag}>(s, a);
+        }}
+    }} else {{
+        using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1comb}, false, {F_squant}, {F_spad}, {F_dvpad}>;

-            return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
-        }}
-    }}
+        return fmha_fwd_splitkv_<traits_, traits2_, {F_arch.tag}>(s, a);
+    }}
+}}
"""


@dataclass
class FmhaFwdSplitKVApiTrait:
+    arch: ArchTrait
    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
-    hdim: str
+    hdim: int
    dtype: str # data type
    mode: str # value from MODE_MAP
    bm0: int # tile size along q seqlen (block size)
@@ -326,6 +333,7 @@ class FmhaFwdSplitKVApiTrait:
    dpad: str
    dvpad: str
    pagedkv: str
+    bn1comb: int # tile size along v head_dim of combine kernel

    @property
    def name(self) -> str:
@@ -523,71 +531,80 @@ class FmhaFwdSplitKVCombinePipeline:

class FmhaFwdSplitKVApiPool:
    def __init__(self, mask_impl):
-        self.pool = dict()
+        self.pool = OrderedDict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdSplitKVApiTrait) -> None:
-        # TODO: do we need to check duplication?
-        if trait.dtype not in self.pool.keys():
-            self.pool[trait.dtype] = dict()
-        if trait.hdim not in self.pool[trait.dtype].keys():
-            self.pool[trait.dtype][trait.hdim] = list()
-
-        self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
+        hdim = trait.hdim
+        ts = (
+            self.pool.setdefault(trait.arch, OrderedDict())
+            .setdefault(trait.dtype, OrderedDict())
+            .setdefault(hdim, [])
+        )
+        check_duplicates_and_paddings(ts, trait)
+        ts.append(copy.copy(trait))

    @property
    def api(self) -> str:
-        per_dtypes = str()
-        for i, dtype in enumerate(self.pool.keys()):
-            per_hdim_case = str()
-            for j, hdim in enumerate(self.pool[dtype].keys()):
-                traits = self.pool[dtype][hdim]
-                inners = str()
-                for k, trait in enumerate(traits):
-                    if_k = "if" if k == 0 else "else if"
-                    inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(
-                        F_if=if_k,
-                        F_mode=MODE_MAP[trait.mode],
-                        F_vlayout=LAYOUT_MAP[trait.vlayout],
-                        F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
-                        F_logits=BOOL_MAP[trait.logits],
-                        F_mask=get_mask_map(self.mask_impl)[trait.mask],
-                        F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
-                        F_bias_check=BIAS_CHECK_MAP[trait.bias],
-                        F_bias=BIAS_MAP[trait.bias],
-                        F_lse=BOOL_MAP[trait.lse],
-                        F_squant=BOOL_MAP[trait.squant],
-                        F_pagedkv=BOOL_MAP[trait.pagedkv],
-                        F_scheck=trait.scheck,
-                        F_skcheck=trait.skcheck,
-                        F_dcheck=trait.dcheck,
-                        F_dvcheck=trait.dvcheck,
-                        F_spad=BOOL_MAP[trait.spad],
-                        F_skpad=BOOL_MAP[trait.skpad],
-                        F_dpad=BOOL_MAP[trait.dpad],
-                        F_dvpad=BOOL_MAP[trait.dvpad],
-                        F_bm0=trait.bm0,
-                        F_bn0=trait.bn0,
-                        F_bk0=trait.bk0,
-                        F_bn1=trait.bn1,
-                        F_bk1=trait.bk1,
-                        F_bk0max=trait.bk0max,
+        per_arch = str()
+        for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()):
+            per_dtypes = str()
+            for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
+                per_hdim_case = str()
+                for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
+                    inners = str()
+                    for i_trait, trait in enumerate(pool_by_hdim):
+                        inners += FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(
+                            F_if=if_(i_trait),
+                            F_arch=arch,
+                            F_mode=MODE_MAP[trait.mode],
+                            F_vlayout=LAYOUT_MAP[trait.vlayout],
+                            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
+                            F_logits=BOOL_MAP[trait.logits],
+                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
+                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
+                            F_bias_check=BIAS_CHECK_MAP[trait.bias],
+                            F_bias=BIAS_MAP[trait.bias],
+                            F_lse=BOOL_MAP[trait.lse],
+                            F_squant=BOOL_MAP[trait.squant],
+                            F_pagedkv=BOOL_MAP[trait.pagedkv],
+                            F_scheck=trait.scheck,
+                            F_skcheck=trait.skcheck,
+                            F_dcheck=trait.dcheck,
+                            F_dvcheck=trait.dvcheck,
+                            F_spad=BOOL_MAP[trait.spad],
+                            F_skpad=BOOL_MAP[trait.skpad],
+                            F_dpad=BOOL_MAP[trait.dpad],
+                            F_dvpad=BOOL_MAP[trait.dvpad],
+                            F_bm0=trait.bm0,
+                            F_bn0=trait.bn0,
+                            F_bk0=trait.bk0,
+                            F_bn1=trait.bn1,
+                            F_bk1=trait.bk1,
+                            F_bk0max=trait.bk0max,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
+                            F_bn1comb=trait.bn1comb,
                        )
+                    per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
+                        F_if=if_(i_hdim),
+                        F_hdim=hdim,
+                        F_dtype=FWD_DTYPE_MAP[dtype],
+                        F_hdim_v=hdim,
+                        F_inner_dispatch=indent(inners),
+                    )
-                if_j = "if" if j == 0 else "else if"
-                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
-                    F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners
-                )
+                per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
+                    F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
+                )
-            if_i = "if" if i == 0 else "else if"
-            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
-                F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
-            )
+            per_arch += FMHA_FWD_API_PER_ARCH.format(
+                F_if=if_(i_arch),
+                F_arch=arch,
+                F_dtype_case=indent(per_dtypes),
+            )
-        if not per_dtypes:
+        if not per_arch:
            # empty dispatch string: add some ignores to suppress unused-parameter warnings in the api
-            per_dtypes += " (void)t ; (void)s ; (void)a;"
+            per_arch = "(void)t; (void)s; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(
-            F_dispatch=per_dtypes
+            F_dispatch=indent(per_arch)
        )

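Two helpers carry the `api` property above: `if_` turns an enumeration index into an `if` / `else if` chain, and nested `setdefault` calls on `OrderedDict` keep the pool ordered by arch, dtype, and hdim. A minimal sketch — the `if_` body is inferred from its usage here; in the source it is imported from `codegen.utils`:

```python
from collections import OrderedDict

def if_(i: int) -> str:
    # First branch opens the dispatch chain, later ones continue it.
    return "if" if i == 0 else "else if"

pool = OrderedDict()

def register(arch, dtype, hdim, trait):
    # Create each nesting level on first use; OrderedDict preserves
    # registration order, so the generated dispatch code is deterministic.
    (pool.setdefault(arch, OrderedDict())
         .setdefault(dtype, OrderedDict())
         .setdefault(hdim, [])
         .append(trait))

register("gfx9", "fp16", 128, "trait_a")
register("gfx9", "fp16", 128, "trait_b")
branches = [if_(i) for i, _ in enumerate(pool["gfx9"]["fp16"][128])]
```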
@@ -605,6 +622,7 @@ class FmhaFwdSplitKVCombineTileSize:

@dataclass
class FmhaFwdSplitKVKernel:
+    F_arch: ArchTrait
    F_idx: int # this is not a tunable, but a counter to differentiate symbol
    F_hdim: int # hdim
    F_dtype: str # data type
@@ -615,8 +633,10 @@ class FmhaFwdSplitKVKernel:

    @property
    def template(self) -> str:
+        assert self.F_pipeline.F_lse == "t"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_KERNEL_BODY.format(
            F_idx=self.F_idx,
+            F_arch=self.F_arch,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=self.F_tile.F_bm0,
@@ -666,36 +686,12 @@ class FmhaFwdSplitKVKernel:

    @property
    def filename(self) -> str:
-        return self.name + ".cpp"
-
-    def api_trait(self) -> FmhaFwdSplitKVApiTrait:
-        return FmhaFwdSplitKVApiTrait(
-            pipeline_tag=self.F_pipeline.tag,
-            hdim=str(self.F_hdim),
-            dtype=self.F_dtype,
-            mode=self.F_mode,
-            bm0=self.F_tile.F_bm0,
-            bn0=self.F_tile.F_bn0,
-            bk0=self.F_tile.F_bk0,
-            bn1=self.F_tile.F_bn1,
-            bk1=self.F_tile.F_bk1,
-            bk0max=self.F_tile.F_bk0max,
-            vlayout=self.F_pipeline.F_vlayout,
-            logits=self.F_pipeline.F_logits,
-            mask=self.F_pipeline.F_mask,
-            bias=self.F_pipeline.F_bias,
-            lse=self.F_pipeline.F_lse,
-            squant=self.F_pipeline.F_squant,
-            pagedkv=self.F_pipeline.F_pagedkv,
-            spad=self.F_pipeline.F_spad,
-            skpad=self.F_pipeline.F_skpad,
-            dpad=self.F_pipeline.F_dpad,
-            dvpad=self.F_pipeline.F_dvpad,
-        )
+        return f"{self.name}{self.F_arch.filename_suffix}.cpp"


@dataclass
class FmhaFwdSplitKVCombineKernel:
+    F_arch: ArchTrait
    F_idx: int # this is not a tunable, but a counter to differentiate symbol
    F_hdim: int # hdim
    F_dtype: str # data type
@@ -707,6 +703,7 @@ class FmhaFwdSplitKVCombineKernel:
    def template(self) -> str:
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
            F_idx=self.F_idx,
+            F_arch=self.F_arch,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bn1=self.F_tile.F_bn1,
@@ -730,85 +727,33 @@ class FmhaFwdSplitKVCombineKernel:

    @property
    def filename(self) -> str:
-        return self.name + ".cpp"
+        return f"{self.name}{self.F_arch.filename_suffix}.cpp"


-# TODO: design a more practical way to do it
-# this is current supported tile size per hdim
-def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
-    if dtype == "fp16" or dtype == "bf16":
-        return {
-            "32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-            "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-            "96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-            "128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-            # "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-            "256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-        } # fmt: skip
-    elif dtype == "fp8" or dtype == "bf8":
-        return {
-            "64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
-            "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
-        } # fmt: skip
-    else:
-        return None
-
-
-def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
-    if dtype == "fp16" or dtype == "bf16":
-        return {
-            "32": FmhaFwdSplitKVCombineTileSize(32, -1),
-            "64": FmhaFwdSplitKVCombineTileSize(32, -1),
-            "96": FmhaFwdSplitKVCombineTileSize(32, -1),
-            "128": FmhaFwdSplitKVCombineTileSize(32, -1),
-            # '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
-            "256": FmhaFwdSplitKVCombineTileSize(32, -1),
-        }
-    elif dtype == "fp8" or dtype == "bf8":
-        return {
-            "64": FmhaFwdSplitKVCombineTileSize(32, -1),
-            "128": FmhaFwdSplitKVCombineTileSize(32, -1),
-            "256": FmhaFwdSplitKVCombineTileSize(32, -1),
-        }
-    else:
-        return None

-def get_fwd_splitkv_blobs(
-    kernel_filter: Optional[str], receipt, mask_impl, optdim_list
-) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]:
-    Pipeline = FmhaFwdSplitKVPipeline
-    Kernel = FmhaFwdSplitKVKernel

    # TODO: we don't support tuning yet, so pick one value for vlayout/pipeline/pad;
    # support this in the future
-    def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]:
+class KernelComponentFactoryBase:
+    @staticmethod
+    def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]:
        # this function populates a list of possible pipelines
        # TODO: the order of the list matters! entries later in the list are also checked later
-        # TODO: currently for qr pipeline, let 't' padding to appear later!!
+        # TODO: currently for qr pipeline, let "t" padding to appear later!!
        # TODO: how to design this more generic?
+        Pipeline = FmhaFwdSplitKVPipeline
        squant = "t" if dtype == "fp8" else "f"
        pipelines = []
        if dtype in ["fp16", "bf16"]:
            for logits, mask, bias, pagedkv in itertools.product(
                ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
            ):
-                pipelines.append(Pipeline( "qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
-                pipelines.append(Pipeline( "qr", "col", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
-
-                pipelines.append(Pipeline( "qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
-                pipelines.append(Pipeline( "qr", "col", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
-
-                pipelines.append(Pipeline( "qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
-                pipelines.append(Pipeline( "qr", "col", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
-
-                pipelines.append(Pipeline( "qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
-                pipelines.append(Pipeline( "qr", "col", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
+                pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
+                pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
+                pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
+                pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
        elif dtype in ["fp8", "bf8"]:
            for logits, mask, bias in itertools.product(
                ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
            ):
-                pipelines.append(Pipeline( "qr", "col", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
+                pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
+                pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
        elif dtype in ["fp8fp16", "fp8bf16"]:
            # TODO
            None
@@ -816,18 +761,122 @@ def get_fwd_splitkv_blobs(
        assert False
        return pipelines

-    gen = list()
-    api_pool = FmhaFwdSplitKVApiPool(mask_impl)
+    @staticmethod
+    def get_combine_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]:
+        Pipeline = FmhaFwdSplitKVCombinePipeline
+        squant = "t" if dtype == "fp8" else "f"
+        pipelines = []
+        if dtype in ["fp16", "bf16"]:
+            for spad, dvpad, lse in itertools.product(
+                ["t", "f"], ["t", "f"], ["t", "f"]
+            ):
+                pipelines.append(Pipeline("unused", spad, dvpad, lse, squant))
+        elif dtype in ["fp8", "bf8"]:
+            # lse kernels are not needed
+            for spad, dvpad in itertools.product(["t", "f"], ["t", "f"]):
+                pipelines.append(Pipeline("unused", spad, dvpad, "f", squant))
+        else:
+            assert False
+        return pipelines

-    for dtype in FWD_DTYPE_MAP.keys():
-        d = get_fmha_fwd_tile_dict_from_dtype(dtype)
+    @staticmethod
+    def get_combine_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
+        # Possible values of F_bn1: 8, 16, 32
+        if dtype in ["fp16", "bf16"]:
+            return {
+                "32": FmhaFwdSplitKVCombineTileSize(32, -1),
+                "64": FmhaFwdSplitKVCombineTileSize(32, -1),
+                "96": FmhaFwdSplitKVCombineTileSize(32, -1),
+                "128": FmhaFwdSplitKVCombineTileSize(32, -1),
+                # "160" : FmhaFwdSplitKVCombineTileSize(32, -1),
+                "256": FmhaFwdSplitKVCombineTileSize(32, -1),
+            }
+        elif dtype in ["fp8", "bf8"]:
+            return {
+                "64": FmhaFwdSplitKVCombineTileSize(32, -1),
+                "128": FmhaFwdSplitKVCombineTileSize(32, -1),
+                "256": FmhaFwdSplitKVCombineTileSize(32, -1),
+            }
+        else:
+            return None


class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx9")
|
||||
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
"32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
} # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
return {
|
||||
"64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
    arch = ArchTrait("gfx12")

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        if dtype in ["fp16", "bf16"]:
            return {
                # bm0, bn0, bk0, bn1, bk1,
                "32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
                "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
                "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
                "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
            } # fmt: skip
        elif dtype in ["fp8", "bf8"]:
            return {
                # bm0, bn0, bk0, bn1, bk1,
                "64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
                "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
            } # fmt: skip
        else:
            return None


def get_factory(target: str):
    # Place more specific architectures first

    if target.startswith("gfx9"):
        return KernelComponentFactoryGfx9

    if target.startswith("gfx12"):
        return KernelComponentFactoryGfx12

    raise Exception(f"Unsupported device target {target}")
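
The prefix dispatch above can be exercised in isolation. A minimal sketch, with empty stand-ins for the factory classes (which in the real script carry the tile-size tables):

```python
# Stand-ins for the real factory classes, for illustration only.
class KernelComponentFactoryGfx9:
    pass


class KernelComponentFactoryGfx12:
    pass


def get_factory(target: str):
    # Prefix match: "gfx90a", "gfx942", ... map to the gfx9 factory,
    # "gfx1200", "gfx1201", ... to the gfx12 factory.
    if target.startswith("gfx9"):
        return KernelComponentFactoryGfx9
    if target.startswith("gfx12"):
        return KernelComponentFactoryGfx12
    raise Exception(f"Unsupported device target {target}")


print(get_factory("gfx942").__name__)   # -> KernelComponentFactoryGfx9
print(get_factory("gfx1201").__name__)  # -> KernelComponentFactoryGfx12
```

Unknown targets raise immediately, so a typo in the build configuration fails at generation time rather than producing an empty kernel set.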
def get_fwd_splitkv_blobs(
    targets: List[str], kernel_filter: Optional[str], receipt, mask_impl, optdim_list
) -> List[FmhaFwdSplitKVKernel]:
    Kernel = FmhaFwdSplitKVKernel

    gen = list()

    factories = get_factories_for_targets(targets, get_factory)

    for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
            for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
                if mode == "group":
                    if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
                        # in group mode, spad/skpad must be true, since we can't predict whether the seqlen of the current batch needs padding or not
@@ -839,6 +888,7 @@ def get_fwd_splitkv_blobs(
                ):
                    continue
                k = Kernel(
                    F_arch=factory.arch,
                    F_idx=0,
                    F_hdim=hdim,
                    F_dtype=dtype,
@@ -892,55 +942,34 @@ def get_fwd_splitkv_blobs(
                if not cond:
                    continue

                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)
    return gen


def get_fwd_splitkv_combine_blobs(
    kernel_filter: Optional[str], receipt, optdim_list
    targets: List[str], kernel_filter: Optional[str], receipt, optdim_list
) -> List[FmhaFwdSplitKVCombineKernel]:
    Pipeline = FmhaFwdSplitKVCombinePipeline
    Kernel = FmhaFwdSplitKVCombineKernel

    # TODO: we don't support tuning yet, so pick one value for vlayout/pipeline/pad;
    # support this in the future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]:
        # this function populates a list of possible pipelines
        # TODO: the order of the list matters! entries later in this list are also checked later
        # TODO: currently for the qr pipeline, let 't' padding appear later!!
        # TODO: how to design this more generically?
        squant = "t" if dtype == "fp8" else "f"
        pipelines = []
        if dtype in ["fp16", "bf16"]:
            for spad, dvpad, lse in itertools.product(
                ["t", "f"], ["t", "f"], ["t", "f"]
            ):
                pipelines.append(Pipeline("unused", spad, dvpad, lse, squant))
        elif dtype in ["fp8", "bf8"]:
            # no need for lse kernels
            pipelines.append(Pipeline("unused", "f", "f", "f", squant))
        else:
            assert False
        return pipelines
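
The ordering TODOs matter because the generated API is an if/else-if chain in which the first matching entry wins. A condensed illustration of that rule, with a toy dispatch loop (not the generator's actual code):

```python
# Dispatch emits an if/else-if chain, so the first matching entry wins.
# More specific (unpadded, "f") variants must precede the catch-all
# padded ("t") variants, or the exact-size kernels are never selected.
pipelines = [
    {"spad": "f"},  # exact-size kernel: valid only when seqlen % tile == 0
    {"spad": "t"},  # padded kernel: valid for any seqlen
]


def dispatch(seqlen, tile=64):
    for p in pipelines:
        if p["spad"] == "t" or seqlen % tile == 0:
            return p
    raise RuntimeError("no kernel")


print(dispatch(128))  # exact-size kernel preferred
print(dispatch(100))  # falls through to the padded kernel
```

If the padded variant were listed first, it would match every shape and the exact-size kernel would be dead code.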

    gen = list()

    for dtype in FWD_DTYPE_MAP.keys():
        d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
    factories = get_factories_for_targets(targets, get_factory)

    for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
        d = factory.get_combine_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
            for pipeline in factory.get_combine_pipelines(dtype, hdim):
                if mode == "group":
                    if pipeline.F_spad != "t":
                        # in group mode, spad/skpad must be true, since we can't predict whether the seqlen of the current batch needs padding or not
                        continue
                k = Kernel(
                    F_arch=factory.arch,
                    F_idx=0,
                    F_hdim=hdim,
                    F_dtype=dtype,
@@ -980,43 +1009,102 @@ def get_fwd_splitkv_combine_blobs(
def write_single_kernel(
    kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path
) -> None:
    (autogen_dir / kernel.filename).write_text(kernel.template)
    update_file(autogen_dir / kernel.filename, kernel.template)


def write_fwd_splitkv_api(api_pool: FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None:
    file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME
    file_path.write_text(api_pool.api)
    update_file(autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME, api_pool.api)


def write_blobs(
    output_dir: Path, filter_list: str, receipt, optdim_list, mask_impl
    targets: List[str],
    output_dir: Path,
    filter_list: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    filter_list = filter_list.split("@")
    filter_list.extend([""] * (2 - len(filter_list)))

    kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list)
    for kernel in kernels:
    combine_kernels = get_fwd_splitkv_combine_blobs(
        targets, filter_list[0], receipt, optdim_list
    )
    for kernel in combine_kernels:
        write_single_kernel(kernel, output_dir)
    api_pool, kernels = get_fwd_splitkv_blobs(
        filter_list[1], receipt, mask_impl, optdim_list
    kernels = get_fwd_splitkv_blobs(
        targets, filter_list[1], receipt, mask_impl, optdim_list
    )
    for kernel in kernels:
        write_single_kernel(kernel, output_dir)

    api_pool = FmhaFwdSplitKVApiPool(mask_impl)
    for kernel in kernels:
        combine_ks = [
            k
            for k in combine_kernels
            if k.F_arch == kernel.F_arch
            and k.F_hdim == kernel.F_hdim
            and k.F_dtype == kernel.F_dtype
            and k.F_mode == kernel.F_mode
            and k.F_pipeline.F_spad == kernel.F_pipeline.F_spad
            and k.F_pipeline.F_dvpad == kernel.F_pipeline.F_dvpad
            and k.F_pipeline.F_lse == "f"
            and k.F_pipeline.F_squant == kernel.F_pipeline.F_squant
        ]
        assert len(combine_ks) == 1, (
            f"{len(combine_ks)} matching FmhaFwdSplitKVCombineKernel for {kernel}"
        )
        combine_kernel = combine_ks[0]
        api_pool.register_traits(
            FmhaFwdSplitKVApiTrait(
                arch=kernel.F_arch,
                pipeline_tag=kernel.F_pipeline.tag,
                hdim=kernel.F_hdim,
                dtype=kernel.F_dtype,
                mode=kernel.F_mode,
                bm0=kernel.F_tile.F_bm0,
                bn0=kernel.F_tile.F_bn0,
                bk0=kernel.F_tile.F_bk0,
                bn1=kernel.F_tile.F_bn1,
                bk1=kernel.F_tile.F_bk1,
                bk0max=kernel.F_tile.F_bk0max,
                vlayout=kernel.F_pipeline.F_vlayout,
                logits=kernel.F_pipeline.F_logits,
                mask=kernel.F_pipeline.F_mask,
                bias=kernel.F_pipeline.F_bias,
                lse=kernel.F_pipeline.F_lse,
                squant=kernel.F_pipeline.F_squant,
                pagedkv=kernel.F_pipeline.F_pagedkv,
                spad=kernel.F_pipeline.F_spad,
                skpad=kernel.F_pipeline.F_skpad,
                dpad=kernel.F_pipeline.F_dpad,
                dvpad=kernel.F_pipeline.F_dvpad,
                bn1comb=combine_kernel.F_tile.F_bn1,
            )
        )
    write_fwd_splitkv_api(api_pool, output_dir)


def list_blobs(
    file_path: Path, filter_list: str, receipt, optdim_list, mask_impl
    targets: List[str],
    file_path: Path,
    filter_list: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    filter_list = filter_list.split("@")
    filter_list.extend([""] * (2 - len(filter_list)))

    with file_path.open("a") as f:
        kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list)
        kernels = get_fwd_splitkv_combine_blobs(
            targets, filter_list[0], receipt, optdim_list
        )
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
        _, kernels = get_fwd_splitkv_blobs(
            filter_list[1], receipt, mask_impl, optdim_list
        kernels = get_fwd_splitkv_blobs(
            targets, filter_list[1], receipt, mask_impl, optdim_list
        )
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")


@@ -3,12 +3,14 @@
# generate kernel instances to speed up compilation

import copy
from dataclasses import dataclass
import fnmatch
import itertools
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

from codegen.arch import ArchTrait, get_factories_for_targets
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
    LAYOUT_MAP,
@@ -21,24 +23,27 @@ from codegen.cpp_symbol_map import (
    BOOL_MAP,
    PIPELINE_ENUM_MAP,
)
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file

from codegen.ops.fmha_fwd import (
    DTYPE_BITS,
    K0_MAX_SUBMAX_MAP,
    FMHA_FWD_KERNEL_HEADER,
    FMHA_FWD_API_PER_ARCH,
    FMHA_FWD_API_PER_DTYPE,
    FMHA_FWD_API_PER_HDIM_CASE,
)

DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}

K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}

FMHA_FWD_PAGEDKV_PIPELINE_MAP = {
    "qr_pagedkv": "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS"
}

FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"
"""

FMHA_FWD_KERNEL_BODY = """
|
||||
#include <iostream>
|
||||
|
||||
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
|
||||
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
@@ -98,10 +103,8 @@ using fmha_kernel_{F_idx} =
|
||||
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
|
||||
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
@@ -109,38 +112,35 @@ float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd
|
||||
auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "fmha_fwd_pagedkv_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{
|
||||
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s) {{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return fmha_fwd_pagedkv_<trait_>(s, a);
|
||||
}}
|
||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
|
||||
}}
|
||||
"""
@dataclass
class FmhaFwdApiTrait:
    arch: ArchTrait
    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
@@ -327,71 +327,79 @@ class FmhaFwdPipeline:

class FmhaFwdApiPool:
    def __init__(self, mask_impl):
        self.pool = dict()
        self.pool = OrderedDict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        if trait.hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][trait.hdim] = list()

        self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
        hdim = trait.hdim
        ts = (
            self.pool.setdefault(trait.arch, OrderedDict())
            .setdefault(trait.dtype, OrderedDict())
            .setdefault(hdim, [])
        )
        check_duplicates_and_paddings(ts, trait)
        ts.append(copy.copy(trait))

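
The chained `setdefault` calls build a three-level mapping (arch -> dtype -> hdim -> list of traits) in one expression, preserving insertion order. A self-contained sketch of the same pattern, with strings standing in for the real trait objects:

```python
from collections import OrderedDict

pool = OrderedDict()


def register(arch, dtype, hdim, trait):
    # Each setdefault returns the (possibly freshly created) inner mapping,
    # so the chain drills down and the final list can be appended to directly.
    ts = (
        pool.setdefault(arch, OrderedDict())
        .setdefault(dtype, OrderedDict())
        .setdefault(hdim, [])
    )
    ts.append(trait)


register("gfx9", "fp16", "128", "kernel_a")
register("gfx9", "fp16", "128", "kernel_b")
register("gfx12", "fp8", "64", "kernel_c")
print(pool["gfx9"]["fp16"]["128"])  # -> ['kernel_a', 'kernel_b']
```

The ordered dictionaries matter because the `api` property below iterates the pool to emit if/else-if chains, so registration order becomes dispatch order.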
    @property
    def api(self) -> str:
        per_dtypes = str()
        for i, dtype in enumerate(self.pool.keys()):
            per_hdim_case = str()
            for j, hdim in enumerate(self.pool[dtype].keys()):
                traits = self.pool[dtype][hdim]
                inners = str()
                for k, trait in enumerate(traits):
                    if_k = "if" if k == 0 else "else if"
                    inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
                        F_if=if_k,
                        F_mode=MODE_MAP[trait.mode],
                        F_vlayout=LAYOUT_MAP[trait.vlayout],
                        F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                        F_logits=BOOL_MAP[trait.logits],
                        F_mask=get_mask_map(self.mask_impl)[trait.mask],
                        F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                        F_bias_check=BIAS_CHECK_MAP[trait.bias],
                        F_bias=BIAS_MAP[trait.bias],
                        F_lse=BOOL_MAP[trait.lse],
                        F_pagedkv=BOOL_MAP[trait.pagedkv],
                        F_skip=BOOL_MAP[trait.skip],
                        F_squant=BOOL_MAP[trait.squant],
                        F_scheck=trait.scheck,
                        F_skcheck=trait.skcheck,
                        F_dcheck=trait.dcheck,
                        F_dvcheck=trait.dvcheck,
                        F_spad=BOOL_MAP[trait.spad],
                        F_skpad=BOOL_MAP[trait.skpad],
                        F_dpad=BOOL_MAP[trait.dpad],
                        F_dvpad=BOOL_MAP[trait.dvpad],
                        F_bm0=trait.bm0,
                        F_bn0=trait.bn0,
                        F_bk0=trait.bk0,
                        F_bn1=trait.bn1,
                        F_bk1=trait.bk1,
                        F_bk0max=trait.bk0max,
        per_arch = str()
        for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()):
            per_dtypes = str()
            for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
                per_hdim_case = str()
                for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
                    inners = str()
                    for i_trait, trait in enumerate(pool_by_hdim):
                        inners += FMHA_FWD_API_INNER_DISPATCH.format(
                            F_if=if_(i_trait),
                            F_arch=arch,
                            F_mode=MODE_MAP[trait.mode],
                            F_vlayout=LAYOUT_MAP[trait.vlayout],
                            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                            F_logits=BOOL_MAP[trait.logits],
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_bias_check=BIAS_CHECK_MAP[trait.bias],
                            F_bias=BIAS_MAP[trait.bias],
                            F_lse=BOOL_MAP[trait.lse],
                            F_pagedkv=BOOL_MAP[trait.pagedkv],
                            F_skip=BOOL_MAP[trait.skip],
                            F_squant=BOOL_MAP[trait.squant],
                            F_scheck=trait.scheck,
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_spad=BOOL_MAP[trait.spad],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad],
                            F_bm0=trait.bm0,
                            F_bn0=trait.bn0,
                            F_bk0=trait.bk0,
                            F_bn1=trait.bn1,
                            F_bk1=trait.bk1,
                            F_bk0max=trait.bk0max,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
                        )
                    per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
                        F_if=if_(i_hdim),
                        F_hdim=hdim,
                        F_dtype=FWD_DTYPE_MAP[dtype],
                        F_hdim_v=trait.bn1,
                        F_inner_dispatch=indent(inners),
                    )
                if_j = "if" if j == 0 else "else if"
                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                    F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners
                per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
                    F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
                )
            if_i = "if" if i == 0 else "else if"
            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
            per_arch += FMHA_FWD_API_PER_ARCH.format(
                F_if=if_(i_arch),
                F_arch=arch,
                F_dtype_case=indent(per_dtypes),
            )
        if not per_dtypes:
        if not per_arch:
            # empty string: add (void) casts to suppress unused-parameter warnings in the api
            per_dtypes += " (void)t ; (void)s ; (void)a;"
            per_arch = "(void)t; (void)s; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=indent(per_arch))


@dataclass
@@ -428,6 +436,7 @@ class FmhaFwdTileSize:

@dataclass
class FmhaFwdKernel:
    F_arch: ArchTrait
    F_idx: int  # this is not a tunable, but a counter to differentiate symbols
    F_hdim: int  # hdim
    F_dtype: str  # data type
@@ -440,6 +449,7 @@ class FmhaFwdKernel:
    def template(self) -> str:
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
            F_idx=self.F_idx,
            F_arch=self.F_arch,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=self.F_tile.F_bm0,
@@ -490,10 +500,11 @@ class FmhaFwdKernel:

    @property
    def filename(self) -> str:
        return self.name + ".cpp"
        return f"{self.name}{self.F_arch.filename_suffix}.cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        return FmhaFwdApiTrait(
            arch=self.F_arch,
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
@@ -519,37 +530,12 @@ class FmhaFwdKernel:
        )


# TODO: design a more practical way to do it
# these are the currently supported tile sizes per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
    if dtype == "fp16" or dtype == "bf16":
        return {
            # "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            # "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            # "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            # "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            # "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
        } # fmt: skip
    elif dtype == "fp8" or dtype == "bf8":
        return {
            "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
            "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
            "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
        } # fmt: skip
    else:
        return None


def get_fwd_blobs(
    kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    # TODO: we don't support tuning yet, so pick one value for vlayout/pipeline/pad;
    # support this in the future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
class KernelComponentFactoryBase:
    @staticmethod
    def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]:
        # this function populates a list of possible pipelines
        # TODO: the order of the list matters! entries later in this list are also checked later
        # TODO: currently for the qr_pagedkv pipeline, let 't' padding appear later!!
        # TODO: currently for the qr_pagedkv pipeline, let "t" padding appear later!!
        # TODO: how to design this more generically?
        squant = "t" if dtype == "fp8" else "f"
        pipelines = []
@@ -576,19 +562,85 @@ def get_fwd_blobs(
        assert False
        return pipelines


class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
    arch = ArchTrait("gfx9")

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        if dtype in ["fp16", "bf16"]:
            return {
                # "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                # "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                # "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                # "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                # "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            } # fmt: skip
        elif dtype in ["fp8", "bf8"]:
            return {
                "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
                "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
                "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
            } # fmt: skip
        else:
            return None


class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
    arch = ArchTrait("gfx12")

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        if dtype in ["fp16", "bf16"]:
            return {
                # bm0, bn0, bk0, bn1, bk1,
                # "32": FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
                # "64": FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
                "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
                # "192": FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
                # "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
            } # fmt: skip
        elif dtype in ["fp8", "bf8"]:
            return {
                # bm0, bn0, bk0, bn1, bk1,
                "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
                "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
                "256": FmhaFwdTileSize( 64, 32, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
            } # fmt: skip
        else:
            return None


def get_factory(target: str):
    # Place more specific architectures first

    if target.startswith("gfx9"):
        return KernelComponentFactoryGfx9

    if target.startswith("gfx12"):
        return KernelComponentFactoryGfx12

    raise Exception(f"Unsupported device target {target}")


def get_fwd_blobs(
    targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)

    for dtype in FWD_DTYPE_MAP.keys():
        d = get_fmha_fwd_tile_dict_from_dtype(dtype)
    factories = get_factories_for_targets(targets, get_factory)

    for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                # if pipeline.F_pagedkv == 'f':
            for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
                # if pipeline.F_pagedkv == "f":
                #     continue
                if mode == "group":
                    if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
@@ -605,6 +657,7 @@ def get_fwd_blobs(
                ):
                    continue
                k = FmhaFwdKernel(
                    F_arch=factory.arch,
                    F_idx=0,
                    F_hdim=hdim,
                    F_dtype=dtype,
@@ -674,27 +727,41 @@ def get_fwd_blobs(


def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    (autogen_dir / kernel.filename).write_text(kernel.template)
    update_file(autogen_dir / kernel.filename, kernel.template)


def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
    (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
    update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)


def write_blobs(
    output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
    targets: List[str],
    output_dir: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    api_pool, kernels = get_fwd_blobs(
        targets, kernel_filter, receipt, optdim_list, mask_impl
    )
    for kernel in kernels:
        write_single_fwd_kernel(kernel, output_dir)
    write_fwd_api(api_pool, output_dir)


def list_blobs(
    file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
    targets: List[str],
    file_path: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    with file_path.open("a") as f:
        _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
        _, kernels = get_fwd_blobs(
            targets, kernel_filter, receipt, optdim_list, mask_impl
        )
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
        f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")

@@ -2,7 +2,9 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import dataclasses
import os.path as path
import textwrap


def update_file(file_path, content):
@@ -19,3 +21,51 @@ def update_file(file_path, content):
        return
    with open(file_path, "w") as file:
        file.write(content)


def indent(code: str, indent: str = "    ") -> str:
    return textwrap.indent(code, indent)


def if_(i: int) -> str:
    return "if" if i == 0 else "else if"


def check_duplicates_and_paddings(traits, trait):
    """Check
    * that the traits list does not already contain a trait with the same parameters;
    * that paddings are consistent: a previous kernel could incorrectly be called before the new one,
      for example, f, _t_, f, t cannot come before f, _f_, f, t.
    """

    fields = [f.name for f in dataclasses.fields(trait)]
    pad_fields = [f for f in fields if "pad" in f]
    non_pad_fields = [f for f in fields if "pad" not in f]
    for prev_trait in traits:
        if any(getattr(trait, f) != getattr(prev_trait, f) for f in non_pad_fields):
            continue
        if all(getattr(trait, f) == getattr(prev_trait, f) for f in pad_fields):
            raise Exception(f"Duplicate found {trait}")
        # Check if the previous kernel can be incorrectly used before the current one,
        # for example, f, _t_, f, t cannot come before f, _f_, f, t
        is_prev_more_restrictive = False
        is_curr_more_restrictive = False
        for f in pad_fields:
            prev_pad = getattr(prev_trait, f)
            pad = getattr(trait, f)
            if isinstance(prev_pad, str):
                prev_pad = 1000000 if prev_pad == "f" else 1
                pad = 1000000 if pad == "f" else 1
            elif isinstance(prev_pad, int):
                prev_pad = 1000000 if prev_pad == 0 else prev_pad
                pad = 1000000 if pad == 0 else pad
            else:
                assert False
            if prev_pad < pad:
                is_prev_more_restrictive = True
            elif prev_pad > pad:
                is_curr_more_restrictive = True
        if is_prev_more_restrictive and not is_curr_more_restrictive:
            raise Exception(
                f"Kernel will never be used because paddings are not ordered correctly:\n{prev_trait} supersedes\n{trait}"
            )

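
The ordering rule can be restated compactly: a padded ("t") flag accepts every shape, an unpadded ("f") flag only exact tile multiples, so an earlier kernel whose pad flags cover a strict superset of a later kernel's shadows it in the if/else-if chain. A condensed restatement of that rule (not the function above itself):

```python
def coverage(pad: str) -> int:
    # "t" (padded) kernels accept any size; "f" kernels only exact tiles.
    return 2 if pad == "t" else 1


def is_shadowed(prev_pads, curr_pads):
    """True if a kernel with prev_pads, placed earlier in the chain,
    would match every input the curr_pads kernel could handle."""
    pairs = list(zip(map(coverage, prev_pads), map(coverage, curr_pads)))
    return all(p >= c for p, c in pairs) and any(p > c for p, c in pairs)


# The docstring example: f, _t_, f, t must not come before f, _f_, f, t.
print(is_shadowed(("f", "t", "f", "t"), ("f", "f", "f", "t")))  # -> True
print(is_shadowed(("f", "f", "f", "t"), ("f", "t", "f", "t")))  # -> False
```

The real check also handles integer pad fields (0 meaning "unrestricted"); this sketch covers only the string case.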
@@ -453,15 +453,15 @@ struct fmha_bwd_dq_dk_dv_traits_
{
};

template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);

template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);

template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_bwd_dq_dk_dv_get_name_();
template <typename Traits_>
template <typename Traits_, typename Arch = void>
int fmha_bwd_dq_dk_dv_maxq_();

template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
@@ -474,13 +474,13 @@ struct fmha_bwd_dot_do_o_traits_
    static constexpr bool kPadDv = kPadDv_;
};

template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);

template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);

template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_bwd_dot_do_o_get_name_();

template <ck_tile::index_t HDim_,
@@ -494,13 +494,13 @@ struct fmha_bwd_convert_dq_traits_
{
};

template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);

template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);

template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_bwd_convert_dq_get_name_();

// This is the public API, will be generated by script
|
||||
|
||||
@@ -1159,7 +1159,7 @@ struct fmha_fwd_traits_
    static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};

-template <typename Traits_>
+template <typename Traits_, typename Arch = void>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);

template <ck_tile::index_t HDim_,

@@ -1210,7 +1210,7 @@ struct fmha_fwd_pagedkv_traits_
    static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};

-template <typename Traits_>
+template <typename Traits_, typename Arch = void>
float fmha_fwd_pagedkv_(const ck_tile::stream_config&, fmha_fwd_pagedkv_args);

template <ck_tile::index_t HDim_,

@@ -1259,10 +1259,10 @@ struct fmha_fwd_splitkv_traits_
    static constexpr bool kIsPagedKV = kIsPagedKV_;
};

-template <typename Traits_>
+template <typename Traits_, typename Arch = void>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);

-template <typename Traits_>
+template <typename Traits_, typename Arch = void>
std::string fmha_fwd_splitkv_get_name_();

template <ck_tile::index_t HDim_,

@@ -1285,10 +1285,10 @@ struct fmha_fwd_splitkv_combine_traits_
    static constexpr bool kPadDv = kPadDv_;
};

-template <typename Traits_>
+template <typename Traits_, typename Arch = void>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);

-template <typename Traits_>
+template <typename Traits_, typename Arch = void>
std::string fmha_fwd_splitkv_combine_get_name_();

// this is used to pattern-match internal kernel implementation, not to instantiate kernel

@@ -1322,10 +1322,10 @@ struct fmha_fwd_appendkv_traits_
    static constexpr bool kIsPagedKV = kIsPagedKV_;
};

-template <typename Traits_>
+template <typename Traits_, typename Arch = void>
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);

-template <typename Traits_>
+template <typename Traits_, typename Arch = void>
float fmha_batch_prefill_(const ck_tile::stream_config&, fmha_batch_prefill_args);

// This is the public API, will be generated by script

@@ -1200,7 +1200,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
    }
};

-auto run_appendkv = [&](const ck_tile::stream_config& sc) {
+auto run_appendkv = [&]([[maybe_unused]] const ck_tile::stream_config& sc) {
#if CK_TILE_FMHA_FWD_APPENDKV_API
    if(need_append_kvcache)
    {

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: MIT
-# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

# generate kernel instances to speed up compilation

import argparse

@@ -38,6 +39,7 @@ assert 0 < len(handlers)


def write_blobs(
+    targets: List[str],
    output_dir: Optional[str],
    api_list: List[str],
    filters_list: List[str],

@@ -54,11 +56,12 @@ def write_blobs(

    for api, kernel_filter in zip(api_list, filters_list):
        handler = handlers[api][HandlerId.WRITE_BLOBS]
-        handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl)
+        handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl)


# list all the files that will be generated
def list_blobs(
+    targets: List[str],
    output_file: Optional[str],
    api_list: List[str],
    filters_list: List[str],

@@ -74,7 +77,7 @@ def list_blobs(

    for api, kernel_filter in zip(api_list, filters_list):
        handler = handlers[api][HandlerId.LIST_BLOBS]
-        handler(file_path, kernel_filter, receipt, optdim_list, mask_impl)
+        handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl)


if __name__ == "__main__":

@@ -82,6 +85,12 @@ if __name__ == "__main__":
        prog="generate",
        description="gen API for CK fmha kernel",
    )
+    parser.add_argument(
+        "--targets",
+        default="gfx9,gfx950",
+        required=False,
+        help="list of GPU targets, separated by comma.",
+    )
    parser.add_argument(
        "-d",
        "--direction",  # we keep 'direction' option for backward compatibility

@@ -142,6 +151,7 @@ if __name__ == "__main__":
    )

    args = parser.parse_args()
+    targets = args.targets.split(",")
    api_list = args.direction.split(",")
    filter_list = args.filter.split(",")
    filter_list.extend([""] * (len(api_list) - len(filter_list)))

@@ -149,6 +159,7 @@ if __name__ == "__main__":

    if args.list_blobs is not None:
        list_blobs(
+            targets,
            args.list_blobs,
            api_list,
            filter_list,

@@ -158,6 +169,7 @@ if __name__ == "__main__":
        )
    else:
        write_blobs(
+            targets,
            args.output_dir,
            api_list,
            filter_list,

@@ -94,7 +94,7 @@ run_fp8_tests() {
for b in 1 2 ; do
for hdim in 64 128 256 ; do

-$EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
+$EXE -prec=fp8 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS

done ; done ; done ; done
}

@@ -105,7 +105,7 @@ run_fp8bf16_tests() {
for b in 1 2 ; do
for hdim in 64 128 256 ; do

-$EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
+$EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS

done ; done ; done ; done
}

@@ -114,9 +114,9 @@ run_fp8fp32_tests() {
for perm in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
-for hdim in 64 128 256 ; do
+for hdim in 128 ; do

-$EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
+$EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS

done ; done ; done ; done
}

@@ -3,6 +3,8 @@

#pragma once

+#include "ck_tile/core/config.hpp"
+#if !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN

#include "ck_tile/core/numeric/integer.hpp"

@@ -3,6 +3,8 @@

#pragma once

+#include "ck_tile/core/config.hpp"
+#if CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN

#include "ck_tile/core/numeric/integer.hpp"

@@ -313,6 +313,12 @@ CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_
}

// Architecture tags
struct gfx9_t
{
};
struct gfx950_t
{
};
struct gfx11_t
{
};

@@ -264,7 +264,7 @@
#endif

#ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
-#if __clang_major__ >= 20
+#if __clang_major__ >= 20 && !(defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 1
#else
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0

@@ -168,8 +168,12 @@ uint16_t float_to_bf16_rtn_asm(float f)
    static constexpr uint32_t FP32_NAN = 0x7fff0000;
    static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;

+#if defined(__GFX9__)
    using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
    uint32x2_t check_nan;
+#else
+    uint32_t check_nan;
+#endif
    uint32_t tmp;
    asm volatile("\n \
        v_cmp_u_f32 %0, %2, %2 \n \

@@ -204,8 +208,12 @@ uint16_t float_to_bf16_rta_asm(float f)
    const uint32_t low_nan = 0x7fff;
    const uint32_t hi_nan = 0x7fff0000;

+#if defined(__GFX9__)
    using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
    uint32x2_t check_nan;
+#else
+    uint32_t check_nan;
+#endif

    asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
                 "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"

@@ -5,11 +5,8 @@

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
-#if __clang_major__ >= 20
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
-#else
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
-#endif
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -184,7 +184,7 @@ namespace impl {
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
{
-#if defined(__gfx94__)
+#if defined(__gfx94__) || defined(__gfx12__)
    // This API is designed to use the _pk_ series of functions
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

@@ -195,7 +195,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
-    // __builtin_amdgcn_cvt_pk_fp8_f32() this builtin require the old value, and
+    // __builtin_amdgcn_cvt_pk_fp8_f32() this builtin requires the old value, and
    // will generate a v_mov_b32 vxxx [old] before cvt, which result in unwanted ISA
    // so we prepare an uninitialized variable purposely, and turn off the warning
    int dummy_old;

@@ -209,13 +209,12 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
    uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
        in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}],
        in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}],
-        dummy_old,
-        false); // false -> WORD0
+        x,
+        true); // true -> WORD1

-    constexpr int32_t m0 = 0x05040100;
-    using vec_t = array<OutDataType, 4>;
+    using vec_t = array<OutDataType, 4>;

-    vec_t d = bit_cast<vec_t>(__builtin_amdgcn_perm(y, x, m0));
+    vec_t d = bit_cast<vec_t>(y);
    out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
});
#pragma clang diagnostic pop
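The removal of `__builtin_amdgcn_perm` works because `__builtin_amdgcn_cvt_pk_fp8_f32` can write its two packed f8 results into either 16-bit word of the destination dword. A Python model of that word-select behavior (a sketch only: the model packs raw bytes, whereas the real builtin first converts two f32 inputs to fp8):

```python
def cvt_pk_fp8_f32_model(lo_byte: int, hi_byte: int, old: int, word1: bool) -> int:
    # Model of the builtin's packing: two 8-bit results go into WORD0 or
    # WORD1 of `old`; the other 16-bit word is left untouched.
    pair = (hi_byte << 8) | lo_byte
    if word1:
        return (old & 0x0000FFFF) | (pair << 16)
    return (old & 0xFFFF0000) | pair

# First call writes bytes 0-1 into WORD0 (the old value is a don't-care dummy),
# second call writes bytes 2-3 into WORD1 of the first result, so all four
# fp8 values land in order without a follow-up v_perm shuffle.
x = cvt_pk_fp8_f32_model(0x10, 0x32, 0, word1=False)
y = cvt_pk_fp8_f32_model(0x54, 0x76, x, word1=True)
assert y == 0x76543210
```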

@@ -28,6 +28,19 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
#endif
}

+template <typename Arch, int MinBlockPerCu, typename Kernel, typename... Args>
+#if CK_TILE_USE_LAUNCH_BOUNDS
+__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
+#endif
+__global__ void kentry(Args... args)
+{
+#if defined(__HIP_DEVICE_COMPILE__)
+    Kernel{}(args...);
+#else
+    (..., (ignore = args, 0));
+#endif
+}

//
// return an anonymous functor(lambda) to be called later
// the KernelImpl should be a class without non-static data member, or let's say

@@ -35,11 +48,27 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
//
// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
//
-template <int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU, typename KernelImpl, typename... Args>
+// Arch can be used to support linking multiple object files that have the same kernel compiled for
+// different architectures. In this case each object file has to use a different tag (gfx9_t,
+// gfx12_t etc.), so the kernel will have different symbols for each architecture.
+//
+template <int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
+          typename Arch = void,
+          typename KernelImpl,
+          typename... Args>
CK_TILE_HOST auto
make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
-    const auto kernel = kentry<MinBlockPerCu, KernelImpl, Args...>;
+    const auto kernel = []() {
+        if constexpr(std::is_void_v<Arch>)
+        {
+            return kentry<MinBlockPerCu, KernelImpl, Args...>;
+        }
+        else
+        {
+            return kentry<Arch, MinBlockPerCu, KernelImpl, Args...>;
+        }
+    }();
    return [=](const stream_config& s) {
        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
    };
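The Arch tag's effect on linking can be modeled in a few lines (a sketch of the idea only; the real mechanism is C++ name mangling of the `kentry` template, not string formatting):

```python
from typing import Optional

def kentry_symbol(kernel: str, arch: Optional[str]) -> str:
    # With Arch = void the symbol is shared across compilations; with a tag
    # type (gfx9_t, gfx12_t, ...) each architecture gets its own symbol, so
    # object files built for different GPUs can be linked into one binary
    # without duplicate-symbol clashes.
    if arch is None:
        return f"kentry<{kernel}>"
    return f"kentry<{arch}, {kernel}>"

assert kentry_symbol("FmhaFwd", None) == "kentry<FmhaFwd>"
assert kentry_symbol("FmhaFwd", "gfx9_t") != kentry_symbol("FmhaFwd", "gfx12_t")
```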

@@ -692,7 +692,17 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
        }
    }

-    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
+    CK_TILE_HOST static dim3 BlockSize()
+    {
+        if(is_wave32())
+        {
+            return dim3(kBlockSize / 2);
+        }
+        else
+        {
+            return dim3(kBlockSize);
+        }
+    }
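The halving in `BlockSize()` follows from `kBlockSize` being sized for 64-wide waves: the same number of warps needs half the threads on wave32 hardware (gfx11/gfx12). A minimal numeric sketch, assuming a hypothetical 4-warp block with `kBlockSize = 256`:

```python
def block_size(k_block_size: int, wave32: bool) -> int:
    # kBlockSize is defined for wave64 (e.g. 4 warps * 64 lanes = 256 threads);
    # on wave32 hardware the same 4 warps only need 128 threads.
    return k_block_size // 2 if wave32 else k_block_size

assert block_size(256, wave32=False) == 256  # gfx9:  4 warps * 64 lanes
assert block_size(256, wave32=True) == 128   # gfx12: 4 warps * 32 lanes
```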

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {

@@ -677,7 +677,17 @@ struct FmhaBwdDQDKDVKernel
        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }

-    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
+    CK_TILE_HOST static dim3 BlockSize()
+    {
+        if(is_wave32())
+        {
+            return dim3(kBlockSize / 2);
+        }
+        else
+        {
+            return dim3(kBlockSize);
+        }
+    }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {

@@ -1171,6 +1181,21 @@ struct FmhaBwdDQDKDVKernel
                scale_rp_undrop,
                dropout);

+#if defined(__gfx12__)
+            // Workaround for a compiler bug (SWDEV-559729): v_wmma instructions can be incorrectly
+            // placed in divergent branches used to store padded tensors (when some lanes are
+            // inactive due to padding). Inline asm with dummy dependencies on VGPRs of the tensors
+            // prevents the compiler from doing this.
+            if constexpr(kPadHeadDimQ > 0)
+            {
+                impl::insert_dummy_dep(dk_acc_tile.get_thread_buffer());
+            }
+            if constexpr(kPadHeadDimV > 0)
+            {
+                impl::insert_dummy_dep(dv_acc_tile.get_thread_buffer());
+            }
+#endif

            KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile, nullptr);
            VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile, nullptr);
        }

@@ -1241,7 +1266,7 @@ struct FmhaBwdOGradDotOKernel
            return n.empty() ? n : std::string("p") + n; }();
        return
            _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s<ODataType>::name) +
-            "_" + (kIsGroupMode ? "group" : "batch") + "_" +
+            "_b" + _TS_(kM0) + "_" + (kIsGroupMode ? "group" : "batch") + "_" +
            ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn);
#undef _SS_
#undef _TS_

@@ -1371,7 +1396,7 @@ struct FmhaBwdOGradDotOKernel
        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }

-    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
+    CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

@@ -1678,7 +1703,7 @@ struct FmhaBwdConvertQGradKernel
        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }

-    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
+    CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

@@ -42,7 +42,7 @@ struct FmhaFwdAppendKVKernel
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on

-    __host__ static std::string GetName()
+    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off

@@ -143,41 +143,41 @@ struct FmhaFwdAppendKVKernel
    {
    };

-    __host__ static constexpr Kargs MakeKargs(void* q_ptr,
+    CK_TILE_HOST static constexpr Kargs MakeKargs(void* q_ptr,
                                              void* k_ptr,
                                              const void* knew_ptr,
                                              void* v_ptr,
                                              const void* vnew_ptr,
                                              ck_tile::index_t seqlen_q,
                                              const void* seqlen_k_ptr,
                                              ck_tile::index_t seqlen_knew,
                                              ck_tile::index_t hdim_q,
                                              ck_tile::index_t hdim_v,
                                              ck_tile::index_t num_head_q,
                                              ck_tile::index_t nhead_ratio_qk,
                                              const void* rotary_cos_ptr,
                                              const void* rotary_sin_ptr,
                                              ck_tile::index_t rotary_dim,
                                              bool has_mask,
                                              const void* block_table_ptr,
                                              ck_tile::index_t batch_stride_block_table,
                                              ck_tile::index_t page_block_size,
                                              const void* cache_batch_idx,
                                              ck_tile::index_t stride_q,
                                              ck_tile::index_t stride_k,
                                              ck_tile::index_t stride_knew,
                                              ck_tile::index_t stride_v,
                                              ck_tile::index_t stride_vnew,
                                              ck_tile::index_t nhead_stride_q,
                                              ck_tile::index_t nhead_stride_k,
                                              ck_tile::index_t nhead_stride_knew,
                                              ck_tile::index_t nhead_stride_v,
                                              ck_tile::index_t nhead_stride_vnew,
                                              ck_tile::index_t batch_stride_q,
                                              ck_tile::index_t batch_stride_k,
                                              ck_tile::index_t batch_stride_knew,
                                              ck_tile::index_t batch_stride_v,
                                              ck_tile::index_t batch_stride_vnew)
    {
        Kargs kargs{
            {q_ptr,

@@ -255,7 +255,7 @@ struct FmhaFwdAppendKVKernel
        return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
    }

-    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
+    CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {

@@ -1079,7 +1079,17 @@ struct FmhaFwdKernel
        }
    }

-    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
+    CK_TILE_HOST static dim3 BlockSize()
+    {
+        if(is_wave32())
+        {
+            return dim3(kBlockSize / 2);
+        }
+        else
+        {
+            return dim3(kBlockSize);
+        }
+    }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {

@@ -865,7 +865,17 @@ struct FmhaFwdPagedKVKernel
        }
    }

-    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
+    CK_TILE_HOST static dim3 BlockSize()
+    {
+        if(is_wave32())
+        {
+            return dim3(kBlockSize / 2);
+        }
+        else
+        {
+            return dim3(kBlockSize);
+        }
+    }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {

@@ -37,7 +37,7 @@ struct FmhaFwdSplitKVCombineKernel
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on

-    __host__ static std::string GetName()
+    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off

@@ -127,7 +127,7 @@ struct FmhaFwdSplitKVCombineKernel
    using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;

    template <bool Cond = !kIsGroupMode>
-    __host__ static constexpr std::enable_if_t<Cond, Kargs>
+    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* lse_acc_ptr,
              const void* o_acc_ptr,
              void* lse_ptr,

@@ -185,7 +185,7 @@ struct FmhaFwdSplitKVCombineKernel
    }

    template <bool Cond = kIsGroupMode>
-    __host__ static constexpr std::enable_if_t<Cond, Kargs>
+    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* lse_acc_ptr,
              const void* o_acc_ptr,
              void* lse_ptr,

@@ -240,8 +240,10 @@ struct FmhaFwdSplitKVCombineKernel
              ck_tile::index_t max_seqlen_q,
              ck_tile::index_t hdim_v)
    {
+        // Recalculate kM0 = get_warp_size() / NThreads on host
+        const index_t m0 = (is_wave32() ? 32 : 64) / FmhaPipeline::Problem::NThreads;
        // TODO: this may need tuning
-        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
+        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, m0) *
                    ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
                    nhead,
                    batch_size);
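The host-side grid computation above reduces to recomputing `kM0` from the runtime warp size. A sketch with hypothetical values (`NThreads = 8`, `kN1 = 32`; the real constants come from the pipeline's problem description):

```python
def grid_dims(max_seqlen_q, hdim_v, nhead, batch, warp_size, n_threads, kN1):
    # kM0 depends on the warp size, so it must be recomputed on the host:
    # kM0 = warp_size / NThreads (8 on wave64, 4 on wave32 when NThreads = 8).
    m0 = warp_size // n_threads
    ceil_div = lambda a, b: -(-a // b)
    return (ceil_div(max_seqlen_q, m0) * ceil_div(hdim_v, kN1), nhead, batch)

# Using the compile-time kM0 of one architecture on the other would launch
# the wrong number of blocks; recomputing keeps both correct.
assert grid_dims(128, 64, 2, 1, 64, 8, 32) == (32, 2, 1)  # wave64: kM0 = 8
assert grid_dims(128, 64, 2, 1, 32, 8, 32) == (64, 2, 1)  # wave32: kM0 = 4
```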

@@ -266,7 +268,17 @@ struct FmhaFwdSplitKVCombineKernel
        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }

-    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
+    CK_TILE_HOST static dim3 BlockSize()
+    {
+        if(is_wave32())
+        {
+            return dim3(kBlockSize / 2);
+        }
+        else
+        {
+            return dim3(kBlockSize);
+        }
+    }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {

@@ -344,7 +356,7 @@ struct FmhaFwdSplitKVCombineKernel
        const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            lse_acc_ptr,
            make_tuple(kargs.num_splits, kargs.seqlen_q),
-            make_tuple(kargs.split_stride_lse_acc, 1),
+            make_tuple(kargs.split_stride_lse_acc, number<1>{}),
            number<FmhaPipeline::kAlignmentLSEacc>{},
            number<1>{});

@@ -358,11 +370,11 @@ struct FmhaFwdSplitKVCombineKernel
        const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            o_acc_ptr,
            make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
-            make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
+            make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, number<1>{}),
            number<FmhaPipeline::kAlignmentOacc>{},
            number<1>{});

-        // read 4 * (kM0, kN1) o_acc tiles simultaneously by 4 warps
+        // read kNumWarps * (kM0, kN1) o_acc tiles simultaneously by kNumWarps warps
        const auto o_acc_dram_view = pad_tensor_view(
            o_acc_dram_naive,
            make_tuple(

@@ -469,7 +481,7 @@ struct FmhaFwdSplitKVCombineKernel
        const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            o_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_v),
-            make_tuple(kargs.row_stride_o, 1),
+            make_tuple(kargs.row_stride_o, number<1>{}),
            number<FmhaPipeline::kAlignmentO>{},
            number<1>{});

@@ -70,7 +70,7 @@ struct FmhaFwdSplitKVKernel
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on

-    __host__ static std::string GetName()
+    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off

@@ -279,7 +279,7 @@ struct FmhaFwdSplitKVKernel
    };

    template <bool Cond = !kIsGroupMode>
-    __host__ static constexpr std::enable_if_t<Cond, Kargs>
+    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,

@@ -409,7 +409,7 @@ struct FmhaFwdSplitKVKernel
    }

    template <bool Cond = kIsGroupMode>
-    __host__ static constexpr std::enable_if_t<Cond, Kargs>
+    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,

@@ -574,7 +574,17 @@ struct FmhaFwdSplitKVKernel
        }
    }

-    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
+    CK_TILE_HOST static dim3 BlockSize()
+    {
+        if(is_wave32())
+        {
+            return dim3(kBlockSize / 2);
+        }
+        else
+        {
+            return dim3(kBlockSize);
+        }
+    }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {

@@ -7,7 +7,6 @@
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
-#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"

@@ -683,26 +682,26 @@ struct BlockFmhaBwdPipelineDefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution()
    {
        using AccDataType = remove_cvref_t<typename Problem::AccDataType>;

        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::kM0;
        constexpr index_t kKPerBlock = Problem::kQKHeaddim;

-        constexpr index_t K1 = 16 / sizeof(AccDataType);
-        constexpr index_t K0 = kKPerBlock / K1;
+        constexpr index_t K2 = GetAlignmentPostQGradAcc<Problem>();
+        constexpr index_t K1 = min(kKPerBlock / K2, get_warp_size());
+        constexpr index_t K0 = kKPerBlock / (K1 * K2);

-        constexpr index_t M2 = get_warp_size() / K0;
+        constexpr index_t M2 = get_warp_size() / K1;
        constexpr index_t M1 = kBlockSize / get_warp_size();
        constexpr index_t M0 = kMPerBlock / (M1 * M2);

        constexpr auto dstr = make_static_tile_distribution(
-            tile_distribution_encoding<sequence<>,
-                                       tuple<sequence<1>, sequence<M0, M1, M2>, sequence<K0, K1>>,
-                                       tuple<sequence<2>, sequence<2, 3>>,
-                                       tuple<sequence<1>, sequence<2, 0>>,
-                                       sequence<1, 2, 3>,
-                                       sequence<0, 0, 1>>{});
+            tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<1>, sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
+                tuple<sequence<2>, sequence<2, 3>>,
+                tuple<sequence<1>, sequence<2, 1>>,
+                sequence<1, 2, 3, 3>,
+                sequence<0, 0, 0, 2>>{});
        static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
                      kMPerBlock * kKPerBlock);
        return dstr;
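The motivation for the K0/K1/K2 split can be checked numerically with a simplified model of the index math (example values: hdim 256, a 4-byte float accumulator, and an alignment of 4 elements; the real values come from the problem traits):

```python
def old_dims(kK, acc_bytes, warp_size):
    # Original scheme: K1 is fixed by a 16-byte access, K0 takes the rest.
    K1 = 16 // acc_bytes
    K0 = kK // K1
    M2 = warp_size // K0  # collapses to 0 when K0 exceeds the warp size
    return K0, K1, M2

def new_dims(kK, K2, warp_size):
    # Fixed scheme: K2 = vector alignment, K1 = lanes along K (capped at the
    # warp size), K0 = leftover K repeats; M2 is derived from K1, so every
    # factor stays positive regardless of warp size.
    K1 = min(kK // K2, warp_size)
    K0 = kK // (K1 * K2)
    M2 = warp_size // K1
    return K0, K1, K2, M2

# hdim 256, float accumulator, wave32: the old factorization degenerates.
assert old_dims(256, 4, 32) == (64, 4, 0)     # M2 == 0 -> invalid distribution
assert new_dims(256, 4, 32) == (2, 32, 4, 1)  # K repeats keep M2 >= 1
```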

@@ -711,27 +710,25 @@ struct BlockFmhaBwdPipelineDefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDramTileDistribution()
    {
        using AccDataType = remove_cvref_t<typename Problem::AccDataType>;

        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::kM0;
        constexpr index_t kKPerBlock = Problem::kQKHeaddim;

-        constexpr index_t K1 = 16 / sizeof(AccDataType);
-        constexpr index_t K0 = kKPerBlock / K1;
+        constexpr index_t K2 = GetAlignmentPostQGrad<Problem>();
+        constexpr index_t K1 = min(kKPerBlock / K2, get_warp_size());
+        constexpr index_t K0 = kKPerBlock / (K1 * K2);

-        constexpr index_t M2 = get_warp_size() / K0;
+        constexpr index_t M2 = get_warp_size() / K1;
        constexpr index_t M1 = kBlockSize / get_warp_size();
        constexpr index_t M0 = kMPerBlock / (M1 * M2);

        constexpr auto dstr = make_static_tile_distribution(
            tile_distribution_encoding<sequence<>,
-                                       tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
+                                       tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
-                                       tuple<sequence<1>, sequence<2, 0>>,
-                                       sequence<1, 2>,
-                                       sequence<0, 1>>{});
+                                       tuple<sequence<1>, sequence<2, 1>>,
+                                       sequence<1, 2, 2>,
+                                       sequence<0, 0, 2>>{});
        static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
                      kMPerBlock * kKPerBlock);
        return dstr;

@@ -31,59 +31,33 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy
                                typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
 
         constexpr auto warp_gemm = []() {
-            constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
-
-            if constexpr(std::is_same_v<typename Problem::QDataType, float> &&
-                         std::is_same_v<typename Problem::KDataType, float> &&
+            if constexpr(get_warp_size() == 64 &&
+                         std::is_same_v<typename Problem::QDataType, fp8_t> &&
+                         std::is_same_v<typename Problem::KDataType, fp8_t> &&
                          std::is_same_v<typename Problem::SaccDataType, float>)
             {
-                static_assert(WarpGemmM == 16);
+                static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
+                static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
+                static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
+
+                // TODO: hard coded here. Otherwise, it produces incorrect results
+                constexpr index_t swizzle_factor = 4;
+                return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
+                    swizzle_factor>{};
+            }
+            else
+            {
+                constexpr bool SwizzleA =
+                    Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
                 return WarpGemmDispatcher<typename Problem::QDataType,
                                           typename Problem::KDataType,
                                           typename Problem::SaccDataType,
                                           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
                                           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
                                           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
-                                          true>{};
+                                          true, // TransposeC
+                                          SwizzleA>{};
             }
-            else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
-                              std::is_same_v<typename Problem::KDataType, half_t> &&
-                              std::is_same_v<typename Problem::SaccDataType, float>)
-            {
-                static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
-                if constexpr(WarpGemmM == 32)
-                    return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else if constexpr(WarpGemmM == 16)
-                    return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
-                else // WarpGemmM == 4
-                    return WarpGemmMfmaF16F16F32M4N64K16{};
-            }
-            else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
-                              std::is_same_v<typename Problem::KDataType, bf16_t> &&
-                              std::is_same_v<typename Problem::SaccDataType, float>)
-            {
-                static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
-                if constexpr(WarpGemmM == 32)
-                    return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else if constexpr(WarpGemmM == 16)
-                    return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
-                else // WarpGemmM == 4
-                    return WarpGemmMfmaBf16Bf16F32M4N64K16{};
-            }
-            else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
-                              std::is_same_v<typename Problem::KDataType, fp8_t> &&
-                              std::is_same_v<typename Problem::SaccDataType, float>)
-            {
-                static_assert(WarpGemmM == 32);
-
-                // TODO: hard coded here. Otherwise, it may incorrect result
-                constexpr index_t swizzle_factor = 4;
-                return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
-                    swizzle_factor>{};
-            } // TODO - bf8_t
         }();
 
         using BlockGemmPolicy =

@@ -273,16 +273,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
             store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
         }
 
-        auto o_acc_4_dist = Policy::template MakeOacc4DramTileDistribution<Problem>();
-        auto o_acc_4_dram_window =
+        // First each warp processes its own part of splits
+
+        auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
+        auto o_acc_dram_window =
             make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                              o_acc_dram_block_window_tmp.get_window_lengths(),
                              o_acc_dram_block_window_tmp.get_window_origin(),
-                             o_acc_4_dist);
+                             o_acc_dist);
 
-        // shape=[4 * KM0, kN1]
-        auto o_acc_4 = make_static_distributed_tensor<OaccDataType>(o_acc_4_dist);
-        clear_tile(o_acc_4);
+        // shape=[kNumWarps * KM0, kN1]
+        auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
+        clear_tile(o_acc);
 
         const index_t padded_num_splits = integer_divide_ceil(num_splits, kNumWarps) * kNumWarps;

@@ -291,73 +293,73 @@ struct BlockFmhaFwdSplitKVCombinePipeline
         // each warp handles a [KM0, kN1] tile
         for(index_t split_start = 0; split_start < padded_num_splits; split_start += kNumWarps)
         {
-            auto o_tile = load_tile(o_acc_4_dram_window);
+            auto o_tile = load_tile(o_acc_dram_window);
             const index_t i_split   = split_start + get_warp_id();
             const index_t row_start = kM0 * get_warp_id();
             {
-                constexpr auto spans = decltype(o_acc_4)::get_distributed_spans();
+                constexpr auto spans = decltype(o_acc)::get_distributed_spans();
                 sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
                     sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
                         constexpr auto i_j_idx = make_tuple(idx0, idx1);
                         const auto x_indices = get_x_indices_from_distributed_indices(
-                            o_acc_4.get_tile_distribution(), i_j_idx);
+                            o_acc.get_tile_distribution(), i_j_idx);
 
                         const auto row = x_indices.at(number<0>{});
 
                         const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split);
-                        o_acc_4(i_j_idx) += lse_scale * o_tile(i_j_idx);
+                        o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx);
                     });
                 });
             }
 
-            move_tile_window(o_acc_4_dram_window, {kNumWarps * kM0, 0});
+            move_tile_window(o_acc_dram_window, {kNumWarps * kM0, 0});
         }
 
-        // 4 o_acc tiles in LDS. shape=[4 * kM0, kN1]
-        OaccDataType* o_acc_4_lds_ptr = static_cast<OaccDataType*>(static_cast<void*>(
-            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeLSEacc<Problem>()));
+        // Then each warps combines partial o_acc results into one
+
+        // kNumWarps o_acc tiles in LDS. shape=[kNumWarps * kM0, kN1]
+        OaccDataType* o_acc_lds_ptr = static_cast<OaccDataType*>(static_cast<void*>(
+            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeLSEacc<Problem>()));
 
         {
-            auto o_acc_4_lds_window = [&]() {
-                auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
-                auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
+            auto o_acc_lds_store_window = [&]() {
+                auto desc = Policy::template MakeOaccLdsBlockDescriptor<Problem>();
+                auto view = make_tensor_view<address_space_enum::lds>(o_acc_lds_ptr, desc);
                 return make_tile_window(view, desc.get_lengths(), {0, 0});
             }();
-            store_tile(o_acc_4_lds_window, o_acc_4);
+            store_tile(o_acc_lds_store_window, o_acc);
         }
 
-        auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
+        auto o_acc_result_dist = Policy::template MakeOaccResultDramTileDistribution<Problem>();
 
-        auto o_acc_4_lds_window = [&]() {
-            auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
-            auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
-            return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_dist);
+        auto o_acc_lds_load_window = [&]() {
+            auto desc = Policy::template MakeOaccLdsBlockDescriptor<Problem>();
+            auto view = make_tensor_view<address_space_enum::lds>(o_acc_lds_ptr, desc);
+            return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_result_dist);
         }();
 
-        auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
-        clear_tile(o_acc);
+        auto o_acc_result = make_static_distributed_tensor<OaccDataType>(o_acc_result_dist);
+        clear_tile(o_acc_result);
 
         __builtin_amdgcn_sched_barrier(0);
         block_sync_lds();
         static_for<0, kNumWarps, 1>{}([&](auto) {
-            auto o_acc_in = load_tile(o_acc_4_lds_window);
+            auto o_acc_in = load_tile(o_acc_lds_load_window);
 
             {
-                constexpr auto spans = decltype(o_acc)::get_distributed_spans();
+                constexpr auto spans = decltype(o_acc_result)::get_distributed_spans();
                 sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
                     sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
                         constexpr auto i_j_idx = make_tuple(idx0, idx1);
-                        o_acc(i_j_idx) += o_acc_in(i_j_idx);
+                        o_acc_result(i_j_idx) += o_acc_in(i_j_idx);
                     });
                 });
             }
 
-            move_tile_window(o_acc_4_lds_window, {kM0, 0});
+            move_tile_window(o_acc_lds_load_window, {kM0, 0});
         });
 
-        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
-
-        return o_acc;
+        return tile_elementwise_in(o_acc_element_func, o_acc_result);
     }
 
     template <typename LSEaccDramBlockWindow,

@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -52,11 +52,11 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
     {
         using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
 
-        constexpr index_t kBlockSize = Problem::kBlockSize;
+        constexpr index_t kNumWarps  = Problem::kNumWarps;
         constexpr index_t kMPerBlock = Problem::kM0;
         constexpr index_t kNPerBlock = Problem::kN1;
 
-        constexpr index_t M1 = kBlockSize / get_warp_size();
+        constexpr index_t M1 = kNumWarps;
        constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
         constexpr index_t N0 = get_warp_size() / M2;
         constexpr index_t N1 = kNPerBlock / N0;
@@ -78,16 +78,16 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc4()
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc()
     {
         return sizeof(typename Problem::OaccDataType) *
-               MakeOacc4LdsBlockDescriptor<Problem>().get_element_space_size();
+               MakeOaccLdsBlockDescriptor<Problem>().get_element_space_size();
     }
 
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
     {
-        return GetSmemSizeLSEacc<Problem>() + GetSmemSizeOacc4<Problem>();
+        return GetSmemSizeLSEacc<Problem>() + GetSmemSizeOacc<Problem>();
     }
 
     // shape=[kMaxSplits, kM0]
@@ -129,8 +129,8 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
     {
         using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
 
-        constexpr index_t kMPerBlock = Problem::kM0;
-        constexpr index_t kNPerBlock = Problem::kMaxSplits;
+        constexpr index_t kMPerBlock = Problem::kMaxSplits;
+        constexpr index_t kNPerBlock = Problem::kM0;
         constexpr index_t NPack =
             GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();

@@ -142,8 +142,9 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
 
         constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
             lse_acc_lds_block_desc_0,
-            make_tuple(make_pass_through_transform(kMPerBlock),
-                       make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
+            make_tuple(
+                make_pass_through_transform(number<kMPerBlock>{}),
+                make_merge_transform(make_tuple(number<kNPerBlock / NPack>{}, number<NPack>{}))),
             make_tuple(sequence<1>{}, sequence<0, 2>{}),
             make_tuple(sequence<0>{}, sequence<1>{}));

@@ -156,8 +157,8 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
     {
         using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
 
-        constexpr index_t kMPerBlock = Problem::kM0;
-        constexpr index_t kNPerBlock = Problem::kMaxSplits;
+        constexpr index_t kMPerBlock = Problem::kMaxSplits;
+        constexpr index_t kNPerBlock = Problem::kM0;
         constexpr index_t NPack =
             GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();

@@ -169,21 +170,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
 
         constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
             lse_acc_lds_block_desc_0,
-            make_tuple(make_pass_through_transform(kMPerBlock),
-                       make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
+            make_tuple(
+                make_pass_through_transform(number<kMPerBlock>{}),
+                make_merge_transform(make_tuple(number<kNPerBlock / NPack>{}, number<NPack>{}))),
             make_tuple(sequence<1>{}, sequence<0, 2>{}),
             make_tuple(sequence<1>{}, sequence<0>{}));
 
         return lse_acc_t_lds_block_desc;
     }
 
-    // 3d + padding, shape=[4 * kM0, kN1]
+    // 3d + padding, shape=[kNumWarps * kM0, kN1]
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4LdsBlockDescriptor()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeOaccLdsBlockDescriptor()
     {
         using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
 
-        constexpr index_t kMPerBlock = 4 * Problem::kM0;
+        constexpr index_t kNumWarps  = Problem::kNumWarps;
+        constexpr index_t kMPerBlock = kNumWarps * Problem::kM0;
         constexpr index_t kNPerBlock = Problem::kN1;
         constexpr index_t NPack =
             GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
@@ -191,17 +194,17 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
         constexpr auto o_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
             make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
             number<8>{},
             number<NPack>{},
             number<1>{});
 
-        constexpr auto o_acc_t_lds_block_desc = transform_tensor_descriptor(
+        constexpr auto o_acc_lds_block_desc = transform_tensor_descriptor(
             o_acc_lds_block_desc_0,
             make_tuple(make_pass_through_transform(kMPerBlock),
                        make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
             make_tuple(sequence<1>{}, sequence<0, 2>{}),
-            make_tuple(sequence<1>{}, sequence<0>{}));
+            make_tuple(sequence<0>{}, sequence<1>{}));
 
-        return o_acc_t_lds_block_desc;
+        return o_acc_lds_block_desc;
     }
 
     // shape=[kM0, kMaxSplits]
@@ -235,12 +238,13 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
                                        sequence<2, 1>>{});
     }
 
-    // similar to MakeOaccDramTileDistribution(), but duplicate same 1-warp encoding 4 times on M
-    // direction
+    // similar to MakeOaccResultDramTileDistribution(), but duplicate same 1-warp encoding kNumWarps
+    // times on M direction
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4DramTileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
     {
-        constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (4 * kM0)
+        constexpr index_t kNumWarps  = Problem::kNumWarps;
+        constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (kNumWarps * kM0)
         constexpr index_t kNPerBlock = Problem::kN1;
         static_assert(get_warp_size() <= kMPerBlock * kNPerBlock);

@@ -252,7 +256,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
 
         return make_static_tile_distribution(
             tile_distribution_encoding<sequence<1>,
-                                       tuple<sequence<4, M0, M1, M2>, sequence<N0, N1>>,
+                                       tuple<sequence<kNumWarps, M0, M1, M2>, sequence<N0, N1>>,
                                        tuple<sequence<1, 1>, sequence<1, 2>>,
                                        tuple<sequence<0, 2>, sequence<3, 0>>,
                                        sequence<1, 2>,
@@ -260,14 +264,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeOaccResultDramTileDistribution()
     {
-        constexpr index_t kBlockSize = Problem::kBlockSize;
+        constexpr index_t kNumWarps  = Problem::kNumWarps;
         constexpr index_t kMPerBlock = Problem::kM0;
         constexpr index_t kNPerBlock = Problem::kN1;
-        static_assert(kBlockSize <= kMPerBlock * kNPerBlock);
+        static_assert(kNumWarps * get_warp_size() <= kMPerBlock * kNPerBlock);
 
-        constexpr index_t M1 = kBlockSize / get_warp_size();
+        constexpr index_t M1 = kNumWarps;
         constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
         constexpr index_t N0 = get_warp_size() / M2;
         constexpr index_t N1 = kNPerBlock / N0;

@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -204,6 +204,7 @@ struct BlockFmhaSplitKVCombinePipelineProblem
 
     using BaseType::kM0;
     using BaseType::kN1;
+    using BaseType::NThreads;
 
     static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
 
@@ -216,7 +217,7 @@ struct BlockFmhaSplitKVCombinePipelineProblem
     static constexpr index_t kMaxSplits = Traits::kMaxSplits;
     static_assert(8 <= kMaxSplits);
 
-    static constexpr index_t kNumWarps = 4; // always use 4 warps for each workgroup
+    static constexpr index_t kNumWarps = 4;
     static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
 
     static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
@@ -58,17 +58,6 @@ struct BlockFmhaPipelineQRKSVS
     static constexpr bool kStoreLSE   = Problem::kStoreLSE;
     static constexpr bool kHasDropout = Problem::kHasDropout;
 
-    using BlockGemm0 = remove_cvref_t<decltype(Policy::template GetQKBlockGemm<Problem>())>;
-    static constexpr auto WarpGemmConfig =
-        BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
-    using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
-    static constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
-    static constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
-    static constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
-    static constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
-    static constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
-    static constexpr int NumMfmaInsts =
-        (kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) * (kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
     static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
     static constexpr uint32_t MFMA    = 0x008; // Barrier for MFMA (matrix multiply-accumulate)

@@ -298,7 +287,18 @@ struct BlockFmhaPipelineQRKSVS
         // Use compile-time conditional for group barrier sequence
         // (No runtime lambda selection)
         auto schedule_gemm0 = [] {
-            if constexpr(kQKHeaddim == 256)
+            using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
+            constexpr auto WarpGemmConfig =
+                BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+            using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
+            constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
+            constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
+            constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
+            constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
+            constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
+            constexpr index_t NumMfmaInsts = (kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) *
+                                             (kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
+            if constexpr(get_warp_size() == 64 && kQKHeaddim == 256)
             {
                 static_assert(NumMfmaInsts % 8 == 0);
                 static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) {

@@ -6,7 +6,6 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
 #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
-#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"

@@ -263,59 +263,33 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
                                typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
 
         constexpr auto warp_gemm = []() {
-            constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
-
-            if constexpr(std::is_same_v<typename Problem::QDataType, float> &&
-                         std::is_same_v<typename Problem::KDataType, float> &&
+            if constexpr(get_warp_size() == 64 &&
+                         std::is_same_v<typename Problem::QDataType, fp8_t> &&
+                         std::is_same_v<typename Problem::KDataType, fp8_t> &&
                          std::is_same_v<typename Problem::SaccDataType, float>)
             {
-                static_assert(WarpGemmM == 16);
+                static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
+                static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
+                static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
+
+                // TODO: hard coded here. Otherwise, it produces incorrect results
+                constexpr index_t swizzle_factor = 4;
+                return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
+                    swizzle_factor>{};
+            }
+            else
+            {
+                constexpr bool SwizzleA =
+                    Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
                 return WarpGemmDispatcher<typename Problem::QDataType,
                                           typename Problem::KDataType,
                                           typename Problem::SaccDataType,
                                           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
                                           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
                                           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
-                                          true>{};
+                                          true, // TransposeC
+                                          SwizzleA>{};
             }
-            else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
-                              std::is_same_v<typename Problem::KDataType, half_t> &&
-                              std::is_same_v<typename Problem::SaccDataType, float>)
-            {
-                static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
-                if constexpr(WarpGemmM == 32)
-                    return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else if constexpr(WarpGemmM == 16)
-                    return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
-                else // WarpGemmM == 4
-                    return WarpGemmMfmaF16F16F32M4N64K16{};
-            }
-            else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
-                              std::is_same_v<typename Problem::KDataType, bf16_t> &&
-                              std::is_same_v<typename Problem::SaccDataType, float>)
-            {
-                static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
-                if constexpr(WarpGemmM == 32)
-                    return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else if constexpr(WarpGemmM == 16)
-                    return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
-                else // WarpGemmM == 4
-                    return WarpGemmMfmaBf16Bf16F32M4N64K16{};
-            }
-            else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
-                              std::is_same_v<typename Problem::KDataType, fp8_t> &&
-                              std::is_same_v<typename Problem::SaccDataType, float>)
-            {
-                static_assert(WarpGemmM == 32);
-
-                // TODO: hard coded here. Otherwise, it may incorrect result
-                constexpr index_t swizzle_factor = 4;
-                return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
-                    swizzle_factor>{};
-            } // TODO - bf8_t
         }();
 
         using BlockGemmPolicy =

@@ -72,59 +72,33 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
                                typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
 
         constexpr auto warp_gemm = []() {
-            constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
-
-            if constexpr(std::is_same_v<typename Problem::QDataType, float> &&
-                         std::is_same_v<typename Problem::KDataType, float> &&
+            if constexpr(get_warp_size() == 64 &&
+                         std::is_same_v<typename Problem::QDataType, fp8_t> &&
+                         std::is_same_v<typename Problem::KDataType, fp8_t> &&
                          std::is_same_v<typename Problem::SaccDataType, float>)
             {
-                static_assert(WarpGemmM == 16);
+                static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
+                static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
+                static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
+
+                // TODO: hard coded here. Otherwise, it produces incorrect results
+                constexpr index_t swizzle_factor = 4;
+                return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
+                    swizzle_factor>{};
+            }
+            else
+            {
+                constexpr bool SwizzleA =
+                    Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
                 return WarpGemmDispatcher<typename Problem::QDataType,
                                           typename Problem::KDataType,
                                           typename Problem::SaccDataType,
                                           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
                                           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
                                           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
-                                          true>{};
+                                          true, // TransposeC
+                                          SwizzleA>{};
             }
-            else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
-                              std::is_same_v<typename Problem::KDataType, half_t> &&
-                              std::is_same_v<typename Problem::SaccDataType, float>)
-            {
-                static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
-                if constexpr(WarpGemmM == 32)
-                    return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else if constexpr(WarpGemmM == 16)
-                    return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
-                else // WarpGemmM == 4
-                    return WarpGemmMfmaF16F16F32M4N64K16{};
-            }
-            else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
-                              std::is_same_v<typename Problem::KDataType, bf16_t> &&
-                              std::is_same_v<typename Problem::SaccDataType, float>)
-            {
-                static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
-                if constexpr(WarpGemmM == 32)
-                    return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else if constexpr(WarpGemmM == 16)
-                    return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
-                else // WarpGemmM == 4
-                    return WarpGemmMfmaBf16Bf16F32M4N64K16{};
-            }
-            else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
-                              std::is_same_v<typename Problem::KDataType, fp8_t> &&
-                              std::is_same_v<typename Problem::SaccDataType, float>)
-            {
-                static_assert(WarpGemmM == 32);
-
-                // TODO: hard coded here. Otherwise, it may incorrect result
-                constexpr index_t swizzle_factor = 4;
-                return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
-                    swizzle_factor>{};
-            } // TODO - bf8_t
         }();
 
         using BlockGemmPolicy =
@@ -238,7 +212,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
             BlockGemmProblem<typename Problem::QDataType,
                              typename Problem::KDataType,
                              typename Problem::SaccDataType,
-                             Problem::kBlockSize,
+                             Problem::kNumGemm0Warps * get_warp_size(),
                              TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
                                                     Problem::BlockFmhaShape::kN0,
                                                     Problem::BlockFmhaShape::kK0>,
@@ -246,59 +220,33 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
                                typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
 
        constexpr auto warp_gemm = []() {
-            constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
-
-            if constexpr(std::is_same_v<typename Problem::QDataType, float> &&
-                         std::is_same_v<typename Problem::KDataType, float> &&
+            if constexpr(get_warp_size() == 64 &&
+                         std::is_same_v<typename Problem::QDataType, fp8_t> &&
+                         std::is_same_v<typename Problem::KDataType, fp8_t> &&
                          std::is_same_v<typename Problem::SaccDataType, float>)
             {
-                static_assert(WarpGemmM == 16);
+                static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
+                static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
+                static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
+
+                // TODO: hard coded here. Otherwise, it produces incorrect results
+                constexpr index_t swizzle_factor = 4;
+                return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
+                    swizzle_factor>{};
+            }
+            else
+            {
+                constexpr bool SwizzleA =
+                    Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
                 return WarpGemmDispatcher<typename Problem::QDataType,
                                           typename Problem::KDataType,
                                           typename Problem::SaccDataType,
                                           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
                                           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
                                           Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
-                                          true>{};
+                                          true, // TransposeC
+                                          SwizzleA>{};
             }
-            else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
-                              std::is_same_v<typename Problem::KDataType, half_t> &&
-                              std::is_same_v<typename Problem::SaccDataType, float>)
-            {
-                static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
-                if constexpr(WarpGemmM == 32)
-                    return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else if constexpr(WarpGemmM == 16)
-                    return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
-                else // WarpGemmM == 4
-                    return WarpGemmMfmaF16F16F32M4N64K16{};
-            }
-            else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
-                              std::is_same_v<typename Problem::KDataType, bf16_t> &&
-                              std::is_same_v<typename Problem::SaccDataType, float>)
-            {
-                static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
-
-                if constexpr(WarpGemmM == 32)
-                    return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else if constexpr(WarpGemmM == 16)
-                    return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
-                else // WarpGemmM == 4
-                    return WarpGemmMfmaBf16Bf16F32M4N64K16{};
-            }
-            else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
-                              std::is_same_v<typename Problem::KDataType, fp8_t> &&
-                              std::is_same_v<typename Problem::SaccDataType, float>)
-            {
-                static_assert(WarpGemmM == 32);
-
-                // TODO: hard coded here. Otherwise, it may incorrect result
-                constexpr index_t swizzle_factor = 4;
-                return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
-                    swizzle_factor>{};
-            } // TODO - bf8_t
         }();
 
         using BlockGemmPolicy =
@@ -481,7 +429,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
        constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG = remove_cvref_t<decltype(config.template at<0>())>;

-       return WG::WarpGemmAttribute::Impl::kCM1PerLane;
+       constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::ODataType);
+       return min(MaxVectorSize, WG::WarpGemmAttribute::Impl::kCM1PerLane);
    }

    template <typename Problem>
@@ -1019,15 +968,16 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
                                  typename Problem::BlockFmhaShape::Gemm1WarpTile>>;

        auto warp_gemm = [&]() {
-           if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
+           if constexpr(get_warp_size() == 64 &&
                         std::is_same_v<typename Problem::PDataType, fp8_t> &&
                         std::is_same_v<typename Problem::VDataType, fp8_t> &&
                         std::is_same_v<typename Problem::OaccDataType, float>)
            {
-               return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
-               // return
-               //     WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
-               //         WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
-               //         Problem::PDataType, typename Problem::VDataType>>>{};
+               static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 32);
+               static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32);
+               static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32);
+
+               return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{};
            }
            else
            {
@@ -29,59 +29,40 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
                                 ? WGAttrNumAccessEnum::Double
                                 : WGAttrNumAccessEnum::Single;

-       if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
-                    std::is_same_v<typename Problem::BDataType, half_t> &&
+       if constexpr(((std::is_same_v<typename Problem::ADataType, half_t> &&
+                      std::is_same_v<typename Problem::BDataType, half_t>) ||
+                     (std::is_same_v<typename Problem::ADataType, bf16_t> &&
+                      std::is_same_v<typename Problem::BDataType, bf16_t>)) &&
                     std::is_same_v<typename Problem::CDataType, float>)
        {
#if 0
            constexpr index_t kBlockSize = Problem::kBlockSize;

            constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
            constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
            constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

            static_assert(kBlockSize % get_warp_size() == 0, "wrong!");

            constexpr index_t NumWarp = kBlockSize / get_warp_size();

-           if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
-                        kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
+           if constexpr(get_warp_size() == 64)
            {
-               return make_tuple(WarpGemmMfmaF16F16F32M32N32K16<>{}, 2, 2);
+               using WG = WarpGemmDispatcher<typename Problem::ADataType,
+                                             typename Problem::BDataType,
+                                             typename Problem::CDataType,
+                                             32,
+                                             32,
+                                             16,
+                                             true,
+                                             false,
+                                             false,
+                                             wg_attr_num_access>;
+               return make_tuple(WG{}, 4, 1);
            }
            else
            {
-               return make_tuple(WarpGemmMfmaF16F16F32M32N32K16<>{}, 2, 2);
+               using WG = WarpGemmDispatcher<typename Problem::ADataType,
+                                             typename Problem::BDataType,
+                                             typename Problem::CDataType,
+                                             16,
+                                             16,
+                                             16,
+                                             true,
+                                             false,
+                                             false,
+                                             wg_attr_num_access>;
+               return make_tuple(WG{}, 4, 1);
            }
#else
            using WG = WarpGemmDispatcher<ck_tile::half_t,
                                          ck_tile::half_t,
                                          float,
                                          32,
                                          32,
                                          16,
                                          true,
                                          false,
                                          false,
                                          wg_attr_num_access>;
            return make_tuple(WG{}, 4, 1);
#endif
        }
        else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
                          std::is_same_v<typename Problem::BDataType, bf16_t> &&
                          std::is_same_v<typename Problem::CDataType, float>)
        {
            using WG = WarpGemmDispatcher<ck_tile::bf16_t,
                                          ck_tile::bf16_t,
                                          float,
                                          32,
                                          32,
                                          16,
                                          true,
                                          false,
                                          false,
                                          wg_attr_num_access>;
            return make_tuple(WG{}, 4, 1);
        }
        else
        {
@@ -384,9 +384,9 @@ using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
        WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;

template <index_t swizzle_factor = 2>
-using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
+using WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
-       WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>,
+       WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
        2,
        swizzle_factor>>;

@@ -50,6 +50,19 @@ struct CWarpDstrEncodingTrait
        typename Impl::kCYs2RHsMinor>;
};

+template <typename Impl>
+struct CTransposedWarpDstrEncodingTrait
+{
+    using type = tile_distribution_encoding<
+        sequence<>,
+        tuple<sequence<Impl::kCNLane>,
+              sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
+        tuple<typename Impl::kCTPs2RHssMajor>,
+        tuple<typename Impl::kCTPs2RHssMinor>,
+        typename Impl::kCTYs2RHsMajor,
+        typename Impl::kCTYs2RHsMinor>;
+};

template <typename WarpGemmAttributeWmmaImpl_, bool kTransC = false>
struct WarpGemmAttributeWmma
{
@@ -75,9 +88,11 @@ struct WarpGemmAttributeWmma
    using AWarpDstrEncoding = typename AWarpDstrEncodingTrait<Impl>::type;
    using BWarpDstrEncoding = typename BWarpDstrEncodingTrait<Impl>::type;

-   // kCM0PerLane = 4, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16 for 16 bit input
-   // kCM0PerLane = 2, kCMLane = 2, kCM1PerLane = 4, kCNLane = 16 for 8 bit input
-   using CWarpDstrEncoding = typename CWarpDstrEncodingTrait<Impl>::type;
+   // kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16
+   using CWarpDstrEncoding =
+       std::conditional_t<kTransC,
+                          typename CTransposedWarpDstrEncodingTrait<Impl>::type,
+                          typename CWarpDstrEncodingTrait<Impl>::type>;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>

@@ -60,6 +60,11 @@ struct WarpGemmAttributeWmmaImpl
    using kCYs2RHsMajor = typename Traits::kCYs2RHsMajor;
    using kCYs2RHsMinor = typename Traits::kCYs2RHsMinor;

+   using kCTPs2RHssMajor = typename Traits::kCTPs2RHssMajor;
+   using kCTPs2RHssMinor = typename Traits::kCTPs2RHssMinor;
+   using kCTYs2RHsMajor = typename Traits::kCTYs2RHsMajor;
+   using kCTYs2RHsMinor = typename Traits::kCTYs2RHsMinor;

    // c_vec += a_vec * b_vec
    template <bool clamp = false, bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,

@@ -46,6 +46,11 @@ struct WmmaTraitsBase<gfx11_t, ADType, BDType, CDType>
    using kCPs2RHssMinor = sequence<1, 0>;
    using kCYs2RHsMajor = sequence<1, 1>;
    using kCYs2RHsMinor = sequence<0, 2>;

+   using kCTPs2RHssMajor = sequence<2, 1>;
+   using kCTPs2RHssMinor = sequence<1, 0>;
+   using kCTYs2RHsMajor = sequence<2, 2>;
+   using kCTYs2RHsMinor = sequence<0, 2>;
};

// GFX12 specialization
@@ -88,5 +93,10 @@ struct WmmaTraitsBase<gfx12_t, ADType, BDType, CDType>
    using kCPs2RHssMinor = sequence<1, 0>;
    using kCYs2RHsMajor = sequence<1, 1>;
    using kCYs2RHsMinor = sequence<0, 2>;

+   using kCTPs2RHssMajor = sequence<2, 1>;
+   using kCTPs2RHssMinor = sequence<1, 0>;
+   using kCTYs2RHsMajor = sequence<2, 2>;
+   using kCTYs2RHsMinor = sequence<0, 2>;
};
} // namespace ck_tile

@@ -1,8 +1,13 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp"

namespace ck_tile {

@@ -40,7 +40,7 @@ execute_process(
    RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
-    message( FATAL_ERROR "CK Tile MHA FAILED to genrate a list of kernels via Python.")
+    message( FATAL_ERROR "CK Tile MHA FAILED to generate a list of kernels via Python.")
else()
    file(STRINGS ${FMHA_CPP_FOLDER}/blob_list.txt FMHA_GEN_BLOBS)
endif()
@@ -74,4 +74,3 @@ add_instance_library(device_mha_instance ${device_files})
if (TARGET device_mha_instance)
    add_dependencies(device_mha_instance generate_cpp_files)
endif()

@@ -1,5 +1,5 @@
# Keep in sync with example/ck_tile/01_fmha/CMakeLists.txt
-if(NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9")
+if(NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx12")
    return()
endif()

@@ -137,6 +137,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         Values(std::tuple{24, 48},
                                std::tuple{48, 48},
                                std::tuple{72, 72},
                                std::tuple{40, 88},
                                std::tuple{96, 96},
                                std::tuple{120, 160},
                                std::tuple{256, 108},

@@ -38,7 +38,7 @@ struct TestConfigs
    static constexpr auto AppendKVHDimValues = std::array{
        std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
    static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
-   static constexpr auto IsVRowmajorValues = std::array{false, true};
+   static constexpr auto IsVRowmajorValues = std::array{true};
    static constexpr bool squant = false;
    static constexpr bool def_lse = true;
    static constexpr bool def_is_v_rowmajor = true;
@@ -47,24 +47,18 @@ struct TestConfigs
template <>
struct TestConfigs<FmhaFwdFp8>
{
    // Currently there are no fp8 instances for splitkv, pagedkv by default (the tests pass if such
    // instances are added), however the corresponding tests are not disabled (they will be skipped)
    // in case such instances are added in the future.

-   static constexpr auto HDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
+   static constexpr auto HDimValues =
+       std::array{std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
    static constexpr auto SplitKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
    static constexpr auto AppendKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
-   // There are no fp8 instances with seqlen padding (mode_enum::group requires it)
-   static constexpr auto ModeValues = std::array{mode_enum::batch};
-   static constexpr auto IsVRowmajorValues = std::array{false};
-   static constexpr bool squant = true;
-   static constexpr bool def_lse = false;
-   static constexpr bool def_is_v_rowmajor = true;
-   static int adjust_seqlen(int seqlen)
-   {
-       // There are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests
-       return ck_tile::integer_least_multiple(seqlen, 128);
-   }
+   static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
+   static constexpr auto IsVRowmajorValues = std::array{true};
+   static constexpr bool squant = true;
+   static constexpr bool def_lse = false;
+   static constexpr bool def_is_v_rowmajor = true;
+   // When there are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests:
+   // return ck_tile::integer_least_multiple(seqlen, 128);
+   static int adjust_seqlen(int seqlen) { return seqlen; }
};
template <>
struct TestConfigs<FmhaFwdFp32>