mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
This commit is contained in:
215
example/ck_tile/01_fmha/CMakeLists.txt
Normal file
215
example/ck_tile/01_fmha/CMakeLists.txt
Normal file
@@ -0,0 +1,215 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
# Currently only gfx9 and gfx12 archs are supported by FMHA
|
||||
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx1[12]")
|
||||
if(NOT INST_TARGETS)
|
||||
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx11, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# validate user-specified fmha_fwd API list
|
||||
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill")
|
||||
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(BUILD_TESTING)
|
||||
# Build instances of all APIs for tests
|
||||
message(DEBUG "Enabling all FWD APIs of CK Tile FMHA for because testing is enabled")
|
||||
set(FMHA_FWD_ENABLE_APIS "all")
|
||||
endif()
|
||||
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
|
||||
set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
|
||||
endif()
|
||||
|
||||
foreach(api ${FMHA_FWD_ENABLE_APIS})
|
||||
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
|
||||
message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
|
||||
if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(PREPEND FMHA_FWD_ENABLE_APIS "fwd")
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
|
||||
)
|
||||
# re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change
|
||||
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
|
||||
|
||||
list(JOIN INST_TARGETS , FMHA_TARGETS_ARG)
|
||||
|
||||
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
set(FMHA_FWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--targets ${FMHA_TARGETS_ARG}
|
||||
--api ${FMHA_FWD_APIS}
|
||||
--optdim 32,64,80,128,256
|
||||
# --filter fmha_fwd...
|
||||
)
|
||||
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--targets ${FMHA_TARGETS_ARG}
|
||||
--api bwd
|
||||
--receipt 3
|
||||
--optdim 32,64,96,128,256
|
||||
# --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd...
|
||||
)
|
||||
|
||||
# Reduce building time by disabling instances that are not currently used in the gtests
|
||||
# TODO: Consider to use a special receipt for testing only, or even two receipts: a small subset of
|
||||
# instances for quick CI runs and a larger subset for scheduled runs (the tests skip tests when
|
||||
# there is no corresponding instance for parameters).
|
||||
if(BUILD_TESTING)
|
||||
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
|
||||
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*)
|
||||
endif()
|
||||
|
||||
# generate a list of kernels, but not actually emit files at config sta
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD kernels via Python.")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.")
|
||||
endif()
|
||||
|
||||
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
|
||||
# as current cmake list, otherwise will not figure out the dependency properly
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
COMMENT "Generate CK Tile FMHA FWD kernels"
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_BWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
COMMENT "Generate CK Tile FMHA BWD kernels"
|
||||
)
|
||||
|
||||
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
|
||||
set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
|
||||
|
||||
message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}")
|
||||
# to save build time, exclude the target from "all" target of "01_fmha" directory and its ancestors
|
||||
add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
|
||||
target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS})
|
||||
set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
|
||||
set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
|
||||
message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}")
|
||||
add_library(${FMHA_BWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
|
||||
target_include_directories(${FMHA_BWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${FMHA_BWD_INSTANCES} PRIVATE ${FMHA_BWD_GEN_BLOBS})
|
||||
set_source_files_properties(${FMHA_BWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
|
||||
set_property(TARGET ${FMHA_BWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
|
||||
set(FMHA_FWD_PRIVATE_COMPILE_OPTIONS)
|
||||
set(FMHA_BWD_PRIVATE_COMPILE_OPTIONS)
|
||||
set(FMHA_FWD_INTERFACE_COMPILE_OPTIONS)
|
||||
set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
# ... because they are auto-generated
|
||||
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template)
|
||||
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template)
|
||||
|
||||
# Allow comparing floating points directly in order to check sentinel values
|
||||
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal)
|
||||
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal)
|
||||
|
||||
# NOTE: this is dangerous since will change the whole kernel to flush denormals
|
||||
# WIP with compiler team for an exp2 intrinsic..., then remove this
|
||||
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
|
||||
set(FMHA_FWD_FAST_EXP2 ON)
|
||||
endif()
|
||||
|
||||
if(FMHA_FWD_FAST_EXP2)
|
||||
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
else()
|
||||
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
endif()
|
||||
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero)
|
||||
|
||||
# conditionally enable call to the fwd_splitkv API in fmha_fwd example and tests
|
||||
if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
|
||||
else()
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
|
||||
endif()
|
||||
|
||||
# conditionally enable call to the fwd_appendkv API in fmha_fwd example and tests
|
||||
if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
|
||||
else()
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
|
||||
endif()
|
||||
|
||||
# conditionally enable call to the pagedkv_prefill API in fmha_fwd example and tests
|
||||
if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1)
|
||||
else()
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
|
||||
endif()
|
||||
|
||||
# conditionally specify the use of OCP_FP8
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
# use RTN_ASM on float to bfloat16 conversion by default, align with FA upstream
|
||||
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
|
||||
list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
|
||||
|
||||
target_compile_options(${FMHA_FWD_INSTANCES}
|
||||
PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS}
|
||||
INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS})
|
||||
target_compile_options(${FMHA_BWD_INSTANCES}
|
||||
PRIVATE ${FMHA_BWD_PRIVATE_COMPILE_OPTIONS}
|
||||
INTERFACE ${FMHA_BWD_INTERFACE_COMPILE_OPTIONS})
|
||||
|
||||
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
|
||||
|
||||
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
|
||||
# not using add_example_executable() to add this target, since we don't want this to be included in
|
||||
# "make all/install/check"
|
||||
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp)
|
||||
target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
|
||||
# not using add_example_executable() to add this target, since we don't want this to be included in
|
||||
# "make all/install/check"
|
||||
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
|
||||
target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
167
example/ck_tile/01_fmha/README.md
Normal file
167
example/ck_tile/01_fmha/README.md
Normal file
@@ -0,0 +1,167 @@
|
||||
# fused multi-head attention
|
||||
|
||||
This folder contains example for fmha(fused multi-head attention) using ck_tile tile-programming implementation. It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast.
|
||||
|
||||
## build
|
||||
```
|
||||
# 1. In the root of composable_kernel project, create the build directory.
|
||||
[~/composable_kernel] mkdir build && cd build
|
||||
# 2. In the build directory, run the CMake wrapper script to generate the build system files. Replace <arch> with the gfx architectures string.
|
||||
[~/composable_kernel/build] ../script/cmake-ck-dev.sh .. <arch> -G Ninja
|
||||
# 3. In the build directory, run the build system recipe.
|
||||
[~/composable_kernel/build] ninja tile_example_fmha_fwd
|
||||
```
|
||||
Running the build recipe will produce the executable `tile_example_fmha_fwd`.
|
||||
|
||||
The executables reside in `bin` subdirectory of the build directory.
|
||||
|
||||
This example provides recipes for `tile_example_fmha_fwd`, `tile_example_fmha_bwd`, `tile_example_fmha_fwd_v3`.
|
||||
|
||||
> [!NOTE]
|
||||
> `cmake-ck-dev.sh` is a CMake wrapper.
|
||||
>
|
||||
> The first argument is the path to composable_kernel sources.
|
||||
>
|
||||
> The second argument is the gfx architectures string (e.g. "gfx950" or "gfx90a;gfx942").
|
||||
>
|
||||
> The remaining arguments are optional and are passed through to CMake.
|
||||
> E.g. `-G Ninja` specifies ninja as the build system.
|
||||
|
||||
## kernel
|
||||
The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template.
|
||||
|
||||
There are 2 template parameters for this kernel template.
|
||||
* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)).
|
||||
* `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support.
|
||||
|
||||
## codegen
|
||||
To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable.
|
||||
|
||||
## executable
|
||||
`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_fmha_fwd -?` to list all the arguments. Below is an example of the output (may subject to change)
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
-mode kernel mode. 0:batch, 1:group (default:0)
|
||||
-b batch size (default:2)
|
||||
-h num of head, for q (default:8)
|
||||
-h_k num of head, for k/v, -1 means equal to h (default:-1)
|
||||
if not equal to h, then this is GQA/MQA case
|
||||
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
|
||||
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
|
||||
also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode)
|
||||
-s_k seqlen_k (including new key/value), -1 means equal to s (default:-1)
|
||||
also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode)
|
||||
-s_qpad seqlen_q stride between 2 batches (group-mode optional) (default:-1)
|
||||
Provide positive strides per-batch to simulate physical padding on Q
|
||||
-s_kpad seqlen_k stride between 2 batches, currently used in group-mode only (default:-1)
|
||||
for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride
|
||||
along seqlen, instead of packed, same as xformer kv_padding,
|
||||
must be greater than or equal to s_k
|
||||
-d head dim for q, k (default:128)
|
||||
-d_v head dim for v, -1 means equal to d (default:-1)
|
||||
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
|
||||
-qscale n or 0, no scaling (default:n)
|
||||
1: per-tensor quantization.
|
||||
-iperm permute input (default:1)
|
||||
if true, will be b*h*s*d, else b*s*h*d
|
||||
-operm permute output (default:1)
|
||||
-bias n or 0, no bias (default:n)
|
||||
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
|
||||
a(libi) or 2, alibi with 1*h. a:1, b*h
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
|
||||
't', top-left causal mask, 'b', bottom-r causal mask
|
||||
't:l,r', top-left sliding window attn(swa) with FA style left right size
|
||||
'b:l,r', bottom-r sliding window attn(swa) with FA style left right size
|
||||
'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
|
||||
'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
|
||||
'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
|
||||
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
|
||||
-lse 0 not store lse, 1 store lse (default:0)
|
||||
-kname if set to 1 will print kernel name (default:0)
|
||||
-init init method. ui, uniform random int, ni, normalized random int (default:uf)
|
||||
uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization
|
||||
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
|
||||
-drop_seed seed for random number generator (default:1)
|
||||
-drop_offset offset for random number generator (default:0)
|
||||
-drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0)
|
||||
-num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:fmha_fwd.json)
|
||||
-q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"")
|
||||
Comma-separated list of length 'b'. If empty, no override
|
||||
-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"")
|
||||
Comma-separated list of length 'b'. If empty, no override
|
||||
```
|
||||
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
|
||||
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case
|
||||
|
||||
## Padding Examples
|
||||
Example 3 (Group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively.
|
||||
|
||||
Example 4 (Batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively.
|
||||
|
||||
## support features
|
||||
Currently we are still in rapid development stage, so more features/optimizations will be coming soon.
|
||||
|
||||
### hdim
|
||||
Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` is better optimized. hdim should be multiple of 8, while seqlen_s can be arbitrary. For hdim be arbitrary number, it can be support through padding kernel of `qr` pipeline (we didn't generate this in generate.py by default)
|
||||
|
||||
### group/batch mode
|
||||
Currently we support both `batch mode` and `group mode` (or `varlen`, in FA's term), by setting `-mode` = `0` or `1`. In `group mode` different kind of attention mask is also supported(see below)
|
||||
|
||||
### MQA/GQA
|
||||
By setting `-h`(nhead for q) and `-h_k`(nhead for k/v) with different number, you can achieve MQA/GQA. Please pay attention that `h % h_K == 0` when you set different numbers.
|
||||
|
||||
### input/output permute, and `b*s*3*h*d`
|
||||
If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support providing arbitrary stride for seqlen(stride_q/k/v), nhead, batch of q/k/v matrix, hence it is very flexible to support `b*h*s*d` or `b*s*h*d` input/output permute. The `-iperm=0/1`, `-operm=0/1` is a convenient way to achieve this through the executable. We didn't provide a command-line arg to test `b*s*3*h*d` layout which is by default used by torch/FA, but it's trivial to achieve this if one set the proper `stride_q/k/v` value as `3*h*d`.
|
||||
|
||||
### attention bias
|
||||
Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number.
|
||||
|
||||
### alibi
|
||||
alibi is supported
|
||||
|
||||
### lse
|
||||
For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1`
|
||||
|
||||
### vlayout
|
||||
We support v matrix in both row-major(`seqlen*hdim`) and col-major(`hdim*seqlen`). Since the accumulate(reduce) dimension for V is along `seqlen`, for current AMD's mfma layout which expect each thread to have contiguous register holding pixels along reduce dimension, it's easier to support col-major V layout. However, the performance of col-major is not necessarily faster than row-major, there are many factors that may affect the overall performance. We still provide the `-vlayout=r/c` here to switch/test between different layouts.
|
||||
|
||||
### attention mask
|
||||
we support `causal mask` and `sliding window attention(swa)` mask in both batch and group mode, either from top-left or bottom-right.
|
||||
Underneath, we unify the mask expression into `generic attention mask coordinate`, providing an uniformed approach for each batch to locate the corresponding pixel need to be masked out.
|
||||

|
||||
|
||||
Since FA/xformer style with window_size_left/right is more popular, we accept window_size as parameter and convert that internally to our generic coordinate(this coordinate can express more cases). Below shows some example of how to achieve different kind of mask through cmdline.
|
||||
|
||||
| mask case| cmdline | FA style | xformer style |
|
||||
|----------|:-------------:|:-------------:|:-------------:|
|
||||
| no mask | `-mask=0`(default) | | |
|
||||
| causal mask from top-left | `-mask=1` or `-mask=t` | `-mask=t:-1,0` | `-mask=xt:-1` |
|
||||
| causal mask from bottom-right | `-mask=2` or `-mask=b` | `-mask=b:-1,0` | `-mask=xb:-1` |
|
||||
| swa from top-left | | `-mask=t:3,5` | `-mask=xt:4` |
|
||||
| swa from bottom-right | | `-mask=b:10,11` | `-mask=xb:16` |
|
||||
|
||||
Note FA use bottom-right by default to express swa case, here we require you explicitly specify top-left/bottom-right.
|
||||
|
||||
### dropout
|
||||
TBD
|
||||
|
||||
### sequence padding and variable length support
|
||||
We support sequence padding and variable-length processing in both batch and group modes fmha forward to handle real-world scenarios where sequences have different lengths.
|
||||
|
||||
**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment.
|
||||
|
||||
**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste.
|
||||
|
||||
Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.
|
||||
|
||||
## FP8 experimental support
|
||||
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+.
|
||||
|
||||
Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later.
|
||||
114
example/ck_tile/01_fmha/bias.hpp
Normal file
114
example/ck_tile/01_fmha/bias.hpp
Normal file
@@ -0,0 +1,114 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
|
||||
// keep sync with BlockAttentionBiasEnum
|
||||
enum class bias_enum
|
||||
{
|
||||
no_bias = 0,
|
||||
elementwise_bias = 1,
|
||||
alibi = 2,
|
||||
};
|
||||
|
||||
struct bias_info
|
||||
{
|
||||
bias_enum type;
|
||||
/*
|
||||
* simple dispatch logic
|
||||
*
|
||||
* if type == elementwise_bias:
|
||||
* if rank_info == 0:
|
||||
* bias is 1*1*s*s
|
||||
* elif rank_info == 1:
|
||||
* bias is 1*h*s*s
|
||||
* elif rank_info == 2:
|
||||
* bias is b*h*s*s
|
||||
*
|
||||
* elif type == alibi:
|
||||
* if rank_info == 0:
|
||||
* alibi in 1*h
|
||||
* elif rank_info == 1:
|
||||
* alibi in b*h
|
||||
*/
|
||||
int rank_info;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == bias_enum::no_bias)
|
||||
os << "n";
|
||||
else if(type == bias_enum::elementwise_bias)
|
||||
{
|
||||
os << "e";
|
||||
if(rank_info != 0)
|
||||
{
|
||||
os << "[" << rank_info << "]";
|
||||
}
|
||||
}
|
||||
else if(type == bias_enum::alibi)
|
||||
{
|
||||
os << "alibi";
|
||||
if(rank_info != 0)
|
||||
{
|
||||
os << "[" << rank_info << "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static bias_info decode(std::string str)
|
||||
{
|
||||
bias_info info{bias_enum::no_bias, 0};
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string t = str.substr(0, found_0);
|
||||
std::string v = str.substr(found_0 + 1);
|
||||
if(t == "e" || t == "elementwise")
|
||||
{
|
||||
info.type = bias_enum::elementwise_bias;
|
||||
info.rank_info = std::stoi(v);
|
||||
if(info.rank_info < 0 || info.rank_info > 2)
|
||||
throw std::invalid_argument("invalid bias rank: " + str);
|
||||
}
|
||||
else if(t == "a" || t == "alibi")
|
||||
{
|
||||
info.type = bias_enum::alibi;
|
||||
info.rank_info = std::stoi(v);
|
||||
if(info.rank_info < 0 || info.rank_info > 1)
|
||||
throw std::invalid_argument("invalid bias rank: " + str);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid bias value: " + str);
|
||||
}
|
||||
}
|
||||
else if(str == "0" || str == "n")
|
||||
{
|
||||
info.type = bias_enum::no_bias;
|
||||
}
|
||||
else if(str == "1" || str == "e" || str == "elementwise")
|
||||
{
|
||||
info.type = bias_enum::elementwise_bias;
|
||||
}
|
||||
else if(str == "2" || str == "a" || str == "alibi")
|
||||
{
|
||||
info.type = bias_enum::alibi;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid bias value: " + str);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const bias_info& bi)
|
||||
{
|
||||
bi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
3
example/ck_tile/01_fmha/codegen/__init__.py
Normal file
3
example/ck_tile/01_fmha/codegen/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
42
example/ck_tile/01_fmha/codegen/arch.py
Normal file
42
example/ck_tile/01_fmha/codegen/arch.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Callable
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ArchTrait:
|
||||
name: str
|
||||
preprocessor_check: str = field(default=None)
|
||||
device_name_check: str = field(default=None)
|
||||
tag: str = field(default=None)
|
||||
filename_suffix: str = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.preprocessor_check is None:
|
||||
object.__setattr__(self, "preprocessor_check", f"defined(__{self.name}__)")
|
||||
if self.device_name_check is None:
|
||||
object.__setattr__(
|
||||
self,
|
||||
"device_name_check",
|
||||
f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0',
|
||||
)
|
||||
if self.tag is None:
|
||||
object.__setattr__(self, "tag", f"ck_tile::{self.name}_t")
|
||||
if self.filename_suffix is None:
|
||||
object.__setattr__(self, "filename_suffix", f"_{self.name}")
|
||||
|
||||
|
||||
def get_factories_for_targets(
|
||||
targets: List[str], get_factory: Callable[[str], Any]
|
||||
) -> List[Any]:
|
||||
factories = dict()
|
||||
for target in targets:
|
||||
factory = get_factory(target)
|
||||
factories[factory.arch.name] = factory
|
||||
# Place more specific architectures first
|
||||
factories = sorted(
|
||||
list(factories.values()), key=lambda f: len(f.arch.name), reverse=True
|
||||
)
|
||||
return factories
|
||||
4
example/ck_tile/01_fmha/codegen/cmake_config.py
Normal file
4
example/ck_tile/01_fmha/codegen/cmake_config.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
GEN_DIR = "" # in Cmake, have to generate files in same folder
|
||||
163
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
Normal file
163
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp32": "FmhaFwdFp32",
|
||||
"fp16": "FmhaFwdFp16",
|
||||
"bf16": "FmhaFwdBf16",
|
||||
"fp8": "FmhaFwdFp8",
|
||||
"fp8fp16": "FmhaFwdFp8Fp16",
|
||||
"fp8bf16": "FmhaFwdFp8Bf16",
|
||||
"fp8fp32": "FmhaFwdFp8Fp32",
|
||||
"mxfp8": "FmhaFwdMxFp8",
|
||||
"mxfp4": "FmhaFwdMxFp4",
|
||||
}
|
||||
|
||||
BWD_DTYPE_MAP = {"fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"}
|
||||
|
||||
MASK_IMPL = {
|
||||
"generic": "ck_tile::GenericAttentionMask",
|
||||
"simplified": "ck_tile::SimplifiedGenericAttentionMask",
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_MAP = {
|
||||
"s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
|
||||
"s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
|
||||
}
|
||||
|
||||
_MASK_MAP = {
|
||||
"no": "FmhaMasks::NoMask",
|
||||
"causal": "FmhaMasks::CausalMask",
|
||||
"generic": "FmhaMasks::GenericMask",
|
||||
}
|
||||
|
||||
|
||||
def get_mask_map(mask_impl: str):
|
||||
if mask_impl == "generic":
|
||||
return _MASK_MAP
|
||||
elif mask_impl == "simplified":
|
||||
return _MASK_SIMPLIFIED_MAP
|
||||
else:
|
||||
assert False
|
||||
return None
|
||||
|
||||
|
||||
def get_mask_impl(mask: str) -> str:
|
||||
return "simplified" if mask.startswith("s_") else "generic"
|
||||
|
||||
|
||||
def get_mask_cpp_type(mask: str) -> str:
|
||||
return get_mask_map(get_mask_impl(mask))[mask]
|
||||
|
||||
|
||||
_MASK_CHECK_MAP = {
|
||||
"no": "t.mask_type == mask_enum::no_mask",
|
||||
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
|
||||
"generic": "t.mask_type == mask_enum::window_generic",
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_CHECK_MAP = {
|
||||
"s_no": "t.mask_type == mask_enum::no_mask",
|
||||
"s_mask": "t.mask_type != mask_enum::no_mask",
|
||||
}
|
||||
|
||||
|
||||
def get_mask_check_map(mask: str):
|
||||
if mask == "generic":
|
||||
return _MASK_CHECK_MAP
|
||||
elif mask == "simplified":
|
||||
return _MASK_SIMPLIFIED_CHECK_MAP
|
||||
else:
|
||||
assert False
|
||||
return None
|
||||
|
||||
|
||||
def get_mask_cpp_check_expr(mask: str) -> str:
|
||||
return get_mask_check_map(get_mask_impl(mask))[mask]
|
||||
|
||||
|
||||
QSCALE_MAP = {
|
||||
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
|
||||
"mx": "ck_tile::BlockAttentionQuantScaleEnum::MX",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
"no": "quant_scale_enum::no_scale",
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
"kv_blockscale": "quant_scale_enum::kv_blockscale",
|
||||
"mx": "quant_scale_enum::mx",
|
||||
}
|
||||
|
||||
BIAS_MAP = {
|
||||
"no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
|
||||
"bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
|
||||
"alibi": "ck_tile::BlockAttentionBiasEnum::ALIBI",
|
||||
}
|
||||
|
||||
# TODO: this is ugly
|
||||
BIAS_CHECK_MAP = {
|
||||
"no": "bias_enum::no_bias",
|
||||
"bias": "bias_enum::elementwise_bias",
|
||||
"alibi": "bias_enum::alibi",
|
||||
}
|
||||
|
||||
DROPOUT_MAP = {
|
||||
"no": "ck_tile::BlockDropoutBwd<false, true, false>",
|
||||
"dropout_wg32": "ck_tile::BlockDropoutBwd<true, true, false>",
|
||||
"dropout_wg32_storerandval": "ck_tile::BlockDropoutBwd<true, true, true >",
|
||||
"dropout_wg16": "ck_tile::BlockDropoutBwd<true, false, false>",
|
||||
"dropout_wg16_storerandval": "ck_tile::BlockDropoutBwd<true, false, true >",
|
||||
}
|
||||
|
||||
DROPOUT_CHECK_MAP = {
|
||||
"no": "t.has_dropout == false",
|
||||
"dropout_wg32": "t.has_dropout == true && t.is_store_randval == false",
|
||||
"dropout_wg32_storerandval": "t.has_dropout == true && t.is_store_randval == true",
|
||||
"dropout_wg16": "t.has_dropout == true && t.is_store_randval == false",
|
||||
"dropout_wg16_storerandval": "t.has_dropout == true && t.is_store_randval == true",
|
||||
}
|
||||
|
||||
ROPE_MAP = {
|
||||
"no": "ck_tile::RotaryEmbeddingEnum::NONE",
|
||||
"inter": "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
|
||||
"half": "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED",
|
||||
}
|
||||
|
||||
ROPE_CHECK_MAP = {
|
||||
"no": "rope_enum::none",
|
||||
"inter": "rope_enum::interleaved",
|
||||
"half": "rope_enum::half_rotated",
|
||||
}
|
||||
|
||||
MODE_MAP = {"batch": "false", "group": "true"}
|
||||
|
||||
LAYOUT_MAP = {"row": "true", "col": "false"}
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"qr": "ck_tile::BlockFmhaPipelineQRKSVS",
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
|
||||
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
|
||||
"qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline",
|
||||
}
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
"qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
"qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
|
||||
"qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
"t": "true",
|
||||
"f": "false",
|
||||
True: "true",
|
||||
False: "false",
|
||||
}
|
||||
3
example/ck_tile/01_fmha/codegen/ops/__init__.py
Normal file
3
example/ck_tile/01_fmha/codegen/ops/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
849
example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
Normal file
849
example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
Normal file
@@ -0,0 +1,849 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
from codegen.cpp_symbol_map import (
|
||||
MODE_MAP,
|
||||
LAYOUT_MAP,
|
||||
BIAS_CHECK_MAP,
|
||||
get_mask_check_map,
|
||||
get_mask_map,
|
||||
BIAS_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
BOOL_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
QSCALE_CHECK_MAP,
|
||||
QSCALE_MAP,
|
||||
)
|
||||
from codegen.utils import update_file
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8": 8,
|
||||
"fp8bf16": 8,
|
||||
"fp8fp32": 8,
|
||||
"bf8": 8,
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
|
||||
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
|
||||
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
|
||||
KV_MEMORY_LAYOUT_ENUM_MAP = {
|
||||
"vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
|
||||
"linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
|
||||
}
|
||||
KV_LOOKUP_TABLE_ENUM_MAP = {
|
||||
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
|
||||
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
|
||||
}
|
||||
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_qscale},
|
||||
{F_occupancy},
|
||||
false,
|
||||
{F_page_size},
|
||||
{F_kv_memory_layout},
|
||||
{F_kv_lookup_table}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
false,
|
||||
{F_page_size},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_batch_prefill_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", {F_kname}" << std::flush;
|
||||
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
#include <cstdio>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cu) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cu = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return "true"
|
||||
else:
|
||||
return f"{self.bool_expr}"
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f"({str(self)}) && ({str(other)})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag: str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
mode: str # value from MODE_MAP
|
||||
bm0: int # tile size along q seqlen (block size)
|
||||
bn0: int # tile size along qk seqlen
|
||||
bk0: int # tile size along qk gemm unroll
|
||||
bn1: int # tile size along v head_dim
|
||||
bk1: int # tile size along kv gemm unroll
|
||||
bk0max: int
|
||||
vlayout: str
|
||||
logits: str
|
||||
mask: str
|
||||
bias: str #
|
||||
lse: str #
|
||||
dropout: str
|
||||
qscale: str #
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
constraint: CppConstraint
|
||||
kv_memory_layout: str
|
||||
kv_lookup_table: str
|
||||
page_size: int = 1 # page block size
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
if self.spad == "t":
|
||||
return "true" # always support
|
||||
else:
|
||||
return "true"
|
||||
elif self.pipeline_tag in ["qr"]:
|
||||
if self.spad == "t":
|
||||
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.seqlen_q % {self.bm0} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
if self.skpad == "t":
|
||||
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
|
||||
else:
|
||||
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
|
||||
elif self.pipeline_tag in ["qr", "qr_fp8"]:
|
||||
if self.skpad == "t":
|
||||
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.seqlen_k % {self.bn0} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == "qr_async":
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == "t":
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == "t":
|
||||
return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_q % {bk0submax} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == "qr_async":
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == "t":
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == "t":
|
||||
return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_v % {bk0submax} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag: str
|
||||
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_logits: str # t/f
|
||||
F_bias: str # true/false
|
||||
F_lse: str #
|
||||
F_dropout: str #
|
||||
F_qscale: str # no/pertensor
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_kv_memory_layout: str #
|
||||
F_kv_lookup_table: str #
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f"{self.tag}_v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
n += "_npad"
|
||||
|
||||
if self.F_logits == "t":
|
||||
n += "_logits"
|
||||
else:
|
||||
n += "_nlogits"
|
||||
|
||||
if self.F_bias != "no":
|
||||
n += f"_{self.F_bias}"
|
||||
else:
|
||||
n += "_nbias"
|
||||
|
||||
if self.F_mask[0:2] == "s_":
|
||||
if self.F_mask == "s_mask":
|
||||
n += "_mask"
|
||||
else:
|
||||
n += "_nmask"
|
||||
else:
|
||||
if self.F_mask != "no":
|
||||
n += f"_m{self.F_mask[0]}"
|
||||
else:
|
||||
n += "_nmask"
|
||||
|
||||
if self.F_lse == "t":
|
||||
n += "_lse"
|
||||
else:
|
||||
n += "_nlse"
|
||||
|
||||
if self.F_dropout == "t":
|
||||
n += "_dropout"
|
||||
else:
|
||||
n += "_ndropout"
|
||||
|
||||
if self.F_qscale != "no":
|
||||
n += f"_{self.F_qscale}"
|
||||
else:
|
||||
n += "_nqscale"
|
||||
|
||||
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
|
||||
return n
|
||||
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes = str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case = str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits = self.pool[dtype][hdim]
|
||||
inners = str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = "if" if k == 0 else "else if"
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
|
||||
F_if=if_k,
|
||||
F_mode=MODE_MAP[trait.mode],
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
F_logits=BOOL_MAP[trait.logits],
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse],
|
||||
F_dropout=BOOL_MAP[trait.dropout],
|
||||
F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
|
||||
F_qscale=QSCALE_MAP[trait.qscale],
|
||||
F_scheck=trait.scheck,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
|
||||
trait.kv_memory_layout
|
||||
],
|
||||
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
|
||||
trait.kv_lookup_table
|
||||
],
|
||||
F_page_size=trait.page_size,
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += " (void)t; (void)s; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_bk0: int # tile size along qk gemm unroll
|
||||
F_bn1: int # tile size along v head_dim
|
||||
F_bk1: int # tile size along kv gemm unroll
|
||||
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0: int # number of warps for gemm0 along q seqlen
|
||||
F_rn0: int # number of warps for gemm0 along k seqlen
|
||||
F_rk0: int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1: int # number of warps for gemm1 along q seqlen
|
||||
F_rn1: int # number of warps for gemm1 along head dim v
|
||||
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0: int # gemm0 warp size along m
|
||||
F_wn0: int # gemm0 warp size along n
|
||||
F_wk0: int # gemm0 warp size along k
|
||||
F_wm1: int # gemm1 warp size along m
|
||||
F_wn1: int # gemm1 warp size along n
|
||||
F_wk1: int # gemm1 warp size along k
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
|
||||
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
|
||||
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
|
||||
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx: int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_mode: str # value from MODE_MAP
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
F_page_size: int = 1 # page block size
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
|
||||
F_kname=self.name,
|
||||
F_idx=self.F_idx,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_tile.F_bm0,
|
||||
F_bn0=self.F_tile.F_bn0,
|
||||
F_bk0=self.F_tile.F_bk0,
|
||||
F_bn1=self.F_tile.F_bn1,
|
||||
F_bk1=self.F_tile.F_bk1,
|
||||
F_bk0max=self.F_tile.F_bk0max,
|
||||
F_rm0=self.F_tile.F_rm0,
|
||||
F_rn0=self.F_tile.F_rn0,
|
||||
F_rk0=self.F_tile.F_rk0,
|
||||
F_rm1=self.F_tile.F_rm1,
|
||||
F_rn1=self.F_tile.F_rn1,
|
||||
F_rk1=self.F_tile.F_rk1,
|
||||
F_wm0=self.F_tile.F_wm0,
|
||||
F_wn0=self.F_tile.F_wn0,
|
||||
F_wk0=self.F_tile.F_wk0,
|
||||
F_wm1=self.F_tile.F_wm1,
|
||||
F_wn1=self.F_tile.F_wn1,
|
||||
F_wk1=self.F_tile.F_wk1,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits=BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias=BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse=BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
|
||||
self.F_pipeline.F_kv_memory_layout
|
||||
],
|
||||
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
|
||||
self.F_pipeline.F_kv_lookup_table
|
||||
],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_page_size=self.F_page_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
qscale=self.F_pipeline.F_qscale,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
|
||||
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
|
||||
page_size=self.F_page_size,
|
||||
)
|
||||
|
||||
|
||||
class KernelComponentFactory:
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
256 : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
} # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
return {
|
||||
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
qscale = "no"
|
||||
for (
|
||||
logits,
|
||||
mask,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
# no need lse/dropout kernels
|
||||
for (
|
||||
logits,
|
||||
qscale,
|
||||
mask,
|
||||
bias,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
["pertensor", "kv_blockscale"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
if 128 in result.keys():
|
||||
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
|
||||
return result
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = CustomFactory.get_hdim_tile_size_dict(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)
|
||||
):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if (
|
||||
pipeline.F_bias != "no"
|
||||
or pipeline.F_lse == "t"
|
||||
or pipeline.F_dropout == "t"
|
||||
):
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not (
|
||||
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
|
||||
or pipeline.F_logits == "f"
|
||||
):
|
||||
continue
|
||||
|
||||
# Generate kernels for both page_size=16 and page_size=1024
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
|
||||
# This ensures all tokens in a main loop iteration belong to the same page
|
||||
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
F_page_size=page_size,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / kernel.filename, kernel.template)
|
||||
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
targets: List[str],
|
||||
output_dir: Path,
|
||||
kernel_filter: str,
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
|
||||
targets: List[str],
|
||||
file_path: Path,
|
||||
kernel_filter: str,
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
1223
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
Normal file
1223
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
Normal file
File diff suppressed because it is too large
Load Diff
1519
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Normal file
1519
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Normal file
File diff suppressed because it is too large
Load Diff
519
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
Normal file
519
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
Normal file
@@ -0,0 +1,519 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
import copy
|
||||
import fnmatch
|
||||
import itertools
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.arch import ArchTrait, get_factories_for_targets
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
from codegen.cpp_symbol_map import (
|
||||
FWD_DTYPE_MAP,
|
||||
BOOL_MAP,
|
||||
ROPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
ROPE_CHECK_MAP,
|
||||
)
|
||||
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
|
||||
|
||||
from codegen.ops.fmha_fwd import (
|
||||
FMHA_FWD_KERNEL_HEADER,
|
||||
FMHA_FWD_API_PER_ARCH,
|
||||
FMHA_FWD_API_PER_DTYPE,
|
||||
FMHA_FWD_API_PER_HDIM_CASE,
|
||||
)
|
||||
|
||||
|
||||
FMHA_FWD_APPENDKV_KERNEL_BODY = """
|
||||
#include <iostream>
|
||||
|
||||
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
|
||||
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
{F_bs},
|
||||
{F_bsk},
|
||||
{F_bd},
|
||||
{F_bdv},
|
||||
{F_vlayout},
|
||||
{F_rope},
|
||||
{F_pagedkv},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
|
||||
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
|
||||
|
||||
template<>
|
||||
float fmha_fwd_appendkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
|
||||
"""
|
||||
|
||||
FMHA_FWD_APPENDKV_API_FILENAME = "fmha_fwd_appendkv_api.cpp"
|
||||
FMHA_FWD_APPENDKV_API = """
|
||||
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s) {{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_APPENDKV_API_INNER_DISPATCH = """{F_if}((t.is_v_rowmajor == {F_vlayout}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
|
||||
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
|
||||
return fmha_fwd_appendkv_<trait_, {F_arch.tag}>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVApiTrait:
|
||||
arch: ArchTrait
|
||||
# sync with fmha_fwd_appendkv_traits, to generate fallback calls
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
bs: int # tile size along q seqlen
|
||||
bsk: int # tile size along k seqlen
|
||||
bd: int # tile size along qk gemm unroll
|
||||
bdv: int # tile size along kv gemm unroll
|
||||
vlayout: str
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
rope: str # key from ROPE_MAP
|
||||
pagedkv: str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-"
|
||||
+ f"{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.spad == "t":
|
||||
return f"true /*a.seqlen_q % {self.bs} != 0*/"
|
||||
else:
|
||||
return f"a.seqlen_q % {self.bs} == 0"
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
# we do not check all the values in a.seqlen_k_ptr
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.dpad == "t":
|
||||
return f"true /*a.hdim_q % {self.bd} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_q % {self.bd} == 0"
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.dvpad == "t":
|
||||
return f"true /*a.hdim_v % {self.bdv} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_v % {self.bdv} == 0"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVPipeline:
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_rope: str # key from ROPE_MAP
|
||||
F_pagedkv: str # t/f
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f"v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
if self.F_rope != "no":
|
||||
n += f"_{self.F_rope}"
|
||||
if self.F_pagedkv == "t":
|
||||
n += "_pagedkv"
|
||||
return n
|
||||
|
||||
|
||||
class FmhaFwdAppendKVApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = OrderedDict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait: FmhaFwdAppendKVApiTrait) -> None:
|
||||
hdim = trait.hdim
|
||||
ts = (
|
||||
self.pool.setdefault(trait.arch, OrderedDict())
|
||||
.setdefault(trait.dtype, OrderedDict())
|
||||
.setdefault(hdim, [])
|
||||
)
|
||||
check_duplicates_and_paddings(ts, trait)
|
||||
ts.append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_arch = str()
|
||||
for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()):
|
||||
per_dtypes = str()
|
||||
for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
|
||||
per_hdim_case = str()
|
||||
for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
|
||||
inners = str()
|
||||
for i_trait, trait in enumerate(pool_by_hdim):
|
||||
inners += FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(
|
||||
F_if=if_(i_trait),
|
||||
F_arch=arch,
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_scheck=trait.scheck,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_rope_check=ROPE_CHECK_MAP[trait.rope],
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_rope=ROPE_MAP[trait.rope],
|
||||
F_bs=trait.bs,
|
||||
F_bsk=trait.bsk,
|
||||
F_bd=trait.bd,
|
||||
F_bdv=trait.bdv,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_(i_hdim),
|
||||
F_hdim=hdim,
|
||||
F_hdim_v=hdim,
|
||||
F_inner_dispatch=indent(inners),
|
||||
)
|
||||
per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
|
||||
)
|
||||
per_arch += FMHA_FWD_API_PER_ARCH.format(
|
||||
F_if=if_(i_arch),
|
||||
F_arch=arch,
|
||||
F_dtype_case=indent(per_dtypes),
|
||||
)
|
||||
if not per_arch:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_arch = "(void)t; (void)s; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(
|
||||
F_dispatch=indent(per_arch)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVTileSize:
|
||||
F_bs: int # tile size along q seqlen
|
||||
F_bsk: int # tile size along k seqlen
|
||||
F_bd: int # tile size along qk gemm unroll
|
||||
F_bdv: int # tile size along kv gemm unroll
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" + (
|
||||
"" if self.F_occupancy == -1 else f"_o{self.F_occupancy}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVKernel:
|
||||
F_arch: ArchTrait
|
||||
F_idx: int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_tile: FmhaFwdAppendKVTileSize
|
||||
F_pipeline: FmhaFwdAppendKVPipeline
|
||||
mask_impl: str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_KERNEL_BODY.format(
|
||||
F_idx=self.F_idx,
|
||||
F_arch=self.F_arch,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bs=self.F_tile.F_bs,
|
||||
F_bsk=self.F_tile.F_bsk,
|
||||
F_bd=self.F_tile.F_bd,
|
||||
F_bdv=self.F_tile.F_bdv,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_rope=ROPE_MAP[self.F_pipeline.F_rope],
|
||||
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdAppendKVApiTrait:
|
||||
return FmhaFwdAppendKVApiTrait(
|
||||
arch=self.F_arch,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
bs=self.F_tile.F_bs,
|
||||
bsk=self.F_tile.F_bsk,
|
||||
bd=self.F_tile.F_bd,
|
||||
bdv=self.F_tile.F_bdv,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
rope=self.F_pipeline.F_rope,
|
||||
pagedkv=self.F_pipeline.F_pagedkv,
|
||||
)
|
||||
|
||||
|
||||
class KernelComponentFactoryBase:
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
"32": FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
|
||||
"64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
|
||||
"128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
|
||||
"256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
|
||||
}
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
return {
|
||||
"64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
|
||||
"128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
|
||||
"256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
# NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
|
||||
# applying rotary embedding, so I just use 't' in inter/half pipelines
|
||||
for vlayout, pagedkv in itertools.product(["row"], ["t", "f"]):
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "f", "f", "no", pagedkv)) # fmt: skip
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "no", pagedkv)) # fmt: skip
|
||||
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "inter", pagedkv)) # fmt: skip
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "inter", pagedkv)) # fmt: skip
|
||||
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "half", pagedkv)) # fmt: skip
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "half", pagedkv)) # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
# rope/paged-kv is not supported
|
||||
pipelines.append(FmhaFwdAppendKVPipeline("row", "t", "t", "t", "t", "no", "f")) # fmt: skip
|
||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx9")
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx11")
|
||||
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return KernelComponentFactoryBase.get_hdim_tile_size_dict(dtype)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return KernelComponentFactoryBase.get_pipelines(dtype, hdim)
|
||||
return []
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx12")
|
||||
|
||||
|
||||
def get_factory(target: str):
|
||||
# Place more specific architectures first
|
||||
|
||||
if target.startswith("gfx9"):
|
||||
return KernelComponentFactoryGfx9
|
||||
|
||||
if target.startswith("gfx11"):
|
||||
return KernelComponentFactoryGfx11
|
||||
if target.startswith("gfx12"):
|
||||
return KernelComponentFactoryGfx12
|
||||
|
||||
raise Exception(f"Unsupported device target {target}")
|
||||
|
||||
|
||||
def get_fwd_appendkv_blobs(
|
||||
targets: List[str], kernel_filter: Optional[str], receipt, mask_impl, optdim_list
|
||||
) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
|
||||
gen = list()
|
||||
api_pool = FmhaFwdAppendKVApiPool(mask_impl)
|
||||
|
||||
factories = get_factories_for_targets(targets, get_factory)
|
||||
|
||||
for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
for hdim_str in d.keys():
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in factory.get_pipelines(dtype, hdim):
|
||||
k = FmhaFwdAppendKVKernel(
|
||||
F_arch=factory.arch,
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt == 2:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / kernel.filename, kernel.template)
|
||||
|
||||
|
||||
def write_fwd_appendkv_api(api_pool: FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME, api_pool.api)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
targets: List[str],
|
||||
output_dir: Path,
|
||||
kernel_filter: Optional[str],
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_appendkv_blobs(
|
||||
targets, kernel_filter, receipt, mask_impl, optdim_list
|
||||
)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
write_fwd_appendkv_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
|
||||
targets: List[str],
|
||||
file_path: Path,
|
||||
kernel_filter: Optional[str],
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_appendkv_blobs(
|
||||
targets, kernel_filter, receipt, mask_impl, optdim_list
|
||||
)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME).as_posix() + "\n")
|
||||
1145
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
Normal file
1145
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
Normal file
File diff suppressed because it is too large
Load Diff
799
example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py
Normal file
799
example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py
Normal file
@@ -0,0 +1,799 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
import copy
|
||||
import fnmatch
|
||||
import itertools
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.arch import ArchTrait, get_factories_for_targets
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
from codegen.cpp_symbol_map import (
|
||||
LAYOUT_MAP,
|
||||
BIAS_CHECK_MAP,
|
||||
get_mask_check_map,
|
||||
MODE_MAP,
|
||||
get_mask_map,
|
||||
BIAS_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
BOOL_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
)
|
||||
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
|
||||
|
||||
from codegen.ops.fmha_fwd import (
|
||||
DTYPE_BITS,
|
||||
K0_MAX_SUBMAX_MAP,
|
||||
FMHA_FWD_KERNEL_HEADER,
|
||||
FMHA_FWD_API_PER_ARCH,
|
||||
FMHA_FWD_API_PER_DTYPE,
|
||||
FMHA_FWD_API_PER_HDIM_CASE,
|
||||
)
|
||||
|
||||
|
||||
FMHA_FWD_PAGEDKV_PIPELINE_MAP = {
|
||||
"qr_pagedkv": "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS"
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
#include <iostream>
|
||||
|
||||
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
|
||||
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse}, //lse
|
||||
{F_pagedkv}, //pagedkv
|
||||
{F_squant},
|
||||
{F_occupancy},
|
||||
{F_skip},
|
||||
{F_sink}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdPagedKVPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>;
|
||||
|
||||
template<>
|
||||
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "fmha_fwd_pagedkv_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s) {{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>;
|
||||
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
arch: ArchTrait
|
||||
pipeline_tag: str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
mode: str # value from MODE_MAP
|
||||
bm0: int # tile size along q seqlen (block size)
|
||||
bn0: int # tile size along qk seqlen
|
||||
bk0: int # tile size along qk gemm unroll
|
||||
bn1: int # tile size along v head_dim
|
||||
bk1: int # tile size along kv gemm unroll
|
||||
bk0max: int
|
||||
vlayout: str
|
||||
logits: str
|
||||
mask: str
|
||||
bias: str #
|
||||
lse: str #
|
||||
pagedkv: str
|
||||
squant: str #
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
skip: str
|
||||
sink: str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
if self.spad == "t":
|
||||
return "true" # always support
|
||||
else:
|
||||
return "true"
|
||||
elif self.pipeline_tag in ["qr_pagedkv", "qs"]:
|
||||
if self.spad == "t":
|
||||
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.seqlen_q % {self.bm0} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
if self.skpad == "t":
|
||||
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
|
||||
else:
|
||||
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
|
||||
elif self.pipeline_tag in ["qr_pagedkv", "qs"]:
|
||||
if self.skpad == "t":
|
||||
return f"true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == "qr_async":
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == "t":
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr_pagedkv", "qs"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == "t":
|
||||
return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_q % {bk0submax} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == "qr_async":
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == "t":
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr_pagedkv", "qs"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == "t":
|
||||
return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_v % {bk0submax} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag: str
|
||||
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_logits: str # t/f
|
||||
F_bias: str # true/false
|
||||
F_lse: str #
|
||||
F_pagedkv: str #
|
||||
F_squant: str #
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_skip: str # true/false
|
||||
F_sink: str # true/false
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f"{self.tag}_v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
n += "_npad"
|
||||
|
||||
if self.F_logits == "t":
|
||||
n += "_logits"
|
||||
else:
|
||||
n += "_nlogits"
|
||||
|
||||
if self.F_bias != "no":
|
||||
n += f"_{self.F_bias}"
|
||||
else:
|
||||
n += "_nbias"
|
||||
|
||||
if self.F_mask[0:2] == "s_":
|
||||
if self.F_mask == "s_mask":
|
||||
n += "_mask"
|
||||
else:
|
||||
n += "_nmask"
|
||||
else:
|
||||
if self.F_mask != "no":
|
||||
n += f"_m{self.F_mask[0]}"
|
||||
else:
|
||||
n += "_nmask"
|
||||
|
||||
if self.F_lse == "t":
|
||||
n += "_lse"
|
||||
else:
|
||||
n += "_nlse"
|
||||
|
||||
if self.F_skip == "t":
|
||||
n += "_skip"
|
||||
else:
|
||||
n += "_nskip"
|
||||
|
||||
if self.F_squant == "t":
|
||||
n += "_squant"
|
||||
else:
|
||||
n += "_nsquant"
|
||||
|
||||
if self.F_pagedkv == "t":
|
||||
n += "_pagedkv"
|
||||
else:
|
||||
n += "_npagedkv"
|
||||
if self.F_sink == "t":
|
||||
n += "_sink"
|
||||
else:
|
||||
n += "_nsink"
|
||||
|
||||
return n
|
||||
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = OrderedDict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
|
||||
hdim = trait.hdim
|
||||
ts = (
|
||||
self.pool.setdefault(trait.arch, OrderedDict())
|
||||
.setdefault(trait.dtype, OrderedDict())
|
||||
.setdefault(hdim, [])
|
||||
)
|
||||
check_duplicates_and_paddings(ts, trait)
|
||||
ts.append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_arch = str()
|
||||
for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()):
|
||||
per_dtypes = str()
|
||||
for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
|
||||
per_hdim_case = str()
|
||||
for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
|
||||
inners = str()
|
||||
for i_trait, trait in enumerate(pool_by_hdim):
|
||||
inners += FMHA_FWD_API_INNER_DISPATCH.format(
|
||||
F_if=if_(i_trait),
|
||||
F_arch=arch,
|
||||
F_mode=MODE_MAP[trait.mode],
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
F_logits=BOOL_MAP[trait.logits],
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse],
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_skip=BOOL_MAP[trait.skip],
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
F_squant=BOOL_MAP[trait.squant],
|
||||
F_scheck=trait.scheck,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_(i_hdim),
|
||||
F_hdim=hdim,
|
||||
F_hdim_v=trait.bn1,
|
||||
F_inner_dispatch=indent(inners),
|
||||
)
|
||||
per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
|
||||
)
|
||||
per_arch += FMHA_FWD_API_PER_ARCH.format(
|
||||
F_if=if_(i_arch),
|
||||
F_arch=arch,
|
||||
F_dtype_case=indent(per_dtypes),
|
||||
)
|
||||
if not per_arch:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_arch = "(void)t; (void)s; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=indent(per_arch))
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_bk0: int # tile size along qk gemm unroll
|
||||
F_bn1: int # tile size along v head_dim
|
||||
F_bk1: int # tile size along kv gemm unroll
|
||||
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0: int # number of warps for gemm0 along q seqlen
|
||||
F_rn0: int # number of warps for gemm0 along k seqlen
|
||||
F_rk0: int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1: int # number of warps for gemm1 along q seqlen
|
||||
F_rn1: int # number of warps for gemm1 along head dim v
|
||||
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0: int # gemm0 warp size along m
|
||||
F_wn0: int # gemm0 warp size along n
|
||||
F_wk0: int # gemm0 warp size along k
|
||||
F_wm1: int # gemm1 warp size along m
|
||||
F_wn1: int # gemm1 warp size along n
|
||||
F_wk1: int # gemm1 warp size along k
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
|
||||
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
|
||||
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
|
||||
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_arch: ArchTrait
|
||||
F_idx: int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_mode: str # value from MODE_MAP
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx=self.F_idx,
|
||||
F_arch=self.F_arch,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_tile.F_bm0,
|
||||
F_bn0=self.F_tile.F_bn0,
|
||||
F_bk0=self.F_tile.F_bk0,
|
||||
F_bn1=self.F_tile.F_bn1,
|
||||
F_bk1=self.F_tile.F_bk1,
|
||||
F_bk0max=self.F_tile.F_bk0max,
|
||||
F_rm0=self.F_tile.F_rm0,
|
||||
F_rn0=self.F_tile.F_rn0,
|
||||
F_rk0=self.F_tile.F_rk0,
|
||||
F_rm1=self.F_tile.F_rm1,
|
||||
F_rn1=self.F_tile.F_rn1,
|
||||
F_rk1=self.F_tile.F_rk1,
|
||||
F_wm0=self.F_tile.F_wm0,
|
||||
F_wn0=self.F_tile.F_wn0,
|
||||
F_wk0=self.F_tile.F_wk0,
|
||||
F_wm1=self.F_tile.F_wm1,
|
||||
F_wn1=self.F_tile.F_wn1,
|
||||
F_wk1=self.F_tile.F_wk1,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits=BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias=BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse=BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
arch=self.F_arch,
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
pagedkv=self.F_pipeline.F_pagedkv,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip,
|
||||
sink=self.F_pipeline.F_sink,
|
||||
)
|
||||
|
||||
|
||||
class KernelComponentFactoryBase:
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr_pagedkv pipeline, let "t" padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = "t" if dtype == "fp8" else "f"
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask, bias, pagedkv, skip, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t"],
|
||||
["f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(
|
||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
|
||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||
pass # TODO
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx9")
|
||||
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
# "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
} # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
return {
|
||||
"64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
"256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx11")
|
||||
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
# "32": FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# "64": FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# "192": FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx12")
|
||||
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
# "32": FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# "64": FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# "192": FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
} # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
"64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"256": FmhaFwdTileSize( 64, 32, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_factory(target: str):
|
||||
# Place more specific architectures first
|
||||
|
||||
if target.startswith("gfx9"):
|
||||
return KernelComponentFactoryGfx9
|
||||
|
||||
if target.startswith("gfx11"):
|
||||
return KernelComponentFactoryGfx11
|
||||
if target.startswith("gfx12"):
|
||||
return KernelComponentFactoryGfx12
|
||||
|
||||
raise Exception(f"Unsupported device target {target}")
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
factories = get_factories_for_targets(targets, get_factory)
|
||||
|
||||
for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
|
||||
# if pipeline.F_pagedkv == "f":
|
||||
# continue
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != "no" or pipeline.F_lse == "t":
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not (
|
||||
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
|
||||
or pipeline.F_logits == "f"
|
||||
):
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_arch=factory.arch,
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / kernel.filename, kernel.template)
|
||||
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
targets: List[str],
|
||||
output_dir: Path,
|
||||
kernel_filter: str,
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(
|
||||
targets, kernel_filter, receipt, optdim_list, mask_impl
|
||||
)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
|
||||
targets: List[str],
|
||||
file_path: Path,
|
||||
kernel_filter: str,
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(
|
||||
targets, kernel_filter, receipt, optdim_list, mask_impl
|
||||
)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
70
example/ck_tile/01_fmha/codegen/utils.py
Normal file
70
example/ck_tile/01_fmha/codegen/utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
import dataclasses
|
||||
import os.path as path
|
||||
import textwrap
|
||||
|
||||
|
||||
def update_file(file_path, content):
|
||||
"""Update the file at file_path with the given content if it differs from the existing content.
|
||||
|
||||
It avoids unnecessary touching of the file which triggers rebuilds
|
||||
"""
|
||||
|
||||
existing_content = ""
|
||||
if path.exists(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
existing_content = file.read()
|
||||
if existing_content == content:
|
||||
return
|
||||
with open(file_path, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
def indent(code: str, indent: str = " ") -> str:
|
||||
return textwrap.indent(code, indent)
|
||||
|
||||
|
||||
def if_(i: int) -> str:
|
||||
return "if" if i == 0 else "else if"
|
||||
|
||||
|
||||
def check_duplicates_and_paddings(traits, trait):
|
||||
"""Check
|
||||
* if the traits list does not contain a trait with the same parameters;
|
||||
* if paddings are consitent: the previous kernel can be incorrectly called before the new one,
|
||||
for example, f, _t_, f, t cannot be before f, _f_, f, t.
|
||||
"""
|
||||
|
||||
fields = [f.name for f in dataclasses.fields(trait)]
|
||||
pad_fields = [f for f in fields if "pad" in f]
|
||||
non_pad_fields = [f for f in fields if "pad" not in f]
|
||||
for prev_trait in traits:
|
||||
if any(getattr(trait, f) != getattr(prev_trait, f) for f in non_pad_fields):
|
||||
continue
|
||||
if all(getattr(trait, f) == getattr(prev_trait, f) for f in pad_fields):
|
||||
raise Exception(f"Duplicate found {trait}")
|
||||
# Check if the previous kernel can be incorrectly used before the current one
|
||||
# for example, f, _t_, f, t cannot be before f, _f_, f, t
|
||||
is_prev_more_restrictive = False
|
||||
is_curr_more_restrictive = False
|
||||
for f in pad_fields:
|
||||
prev_pad = getattr(prev_trait, f)
|
||||
pad = getattr(trait, f)
|
||||
if isinstance(prev_pad, str):
|
||||
prev_pad = 1000000 if prev_pad == "f" else 1
|
||||
pad = 1000000 if pad == "f" else 1
|
||||
elif isinstance(prev_pad, int):
|
||||
prev_pad = 1000000 if prev_pad == 0 else prev_pad
|
||||
pad = 1000000 if pad == 0 else pad
|
||||
else:
|
||||
assert False
|
||||
if prev_pad < pad:
|
||||
is_prev_more_restrictive = True
|
||||
elif prev_pad > pad:
|
||||
is_curr_more_restrictive = True
|
||||
if is_prev_more_restrictive and not is_curr_more_restrictive:
|
||||
raise Exception(
|
||||
f"Kernel will never be used because paddings are not ordered correctly:\n{prev_trait} supersedes\n{trait}"
|
||||
)
|
||||
199
example/ck_tile/01_fmha/example_fmha_bwd.cpp
Normal file
199
example/ck_tile/01_fmha/example_fmha_bwd.cpp
Normal file
@@ -0,0 +1,199 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "fmha_bwd.hpp"
|
||||
#include "fmha_bwd_runner.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "whether do CPU validation or not")
|
||||
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
|
||||
.insert("b", "2", "batch size")
|
||||
.insert("h", "8", "num of head, for q")
|
||||
.insert("h_k",
|
||||
"-1",
|
||||
"num of head, for k/v, -1 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s",
|
||||
"3328",
|
||||
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
|
||||
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_qpad",
|
||||
"-1",
|
||||
"padded seqlen_q per batch (group mode only). "
|
||||
"Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding")
|
||||
.insert("s_k",
|
||||
"-1",
|
||||
"seqlen_k, -1 means equal to s\n"
|
||||
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"padded seqlen_k per batch (group mode only). "
|
||||
"Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("bias",
|
||||
"n",
|
||||
"n or 0, no bias\n"
|
||||
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
|
||||
"a(libi) or 2, alibi with 1*h. a:1, b*h")
|
||||
.insert("dbias", "0", "output bias gradient or not")
|
||||
.insert("prec", "fp16", "data type. fp32/fp16/bf16")
|
||||
.insert("mask",
|
||||
"0",
|
||||
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
|
||||
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
|
||||
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
|
||||
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
|
||||
"'xt:window_size', xformer style masking from top-left, window_size negative is "
|
||||
"causal, positive is swa\n"
|
||||
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
|
||||
"causal, positive is swa\n"
|
||||
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
|
||||
"now)")
|
||||
.insert("kname", "0", "if set to 1 will print kernel name")
|
||||
.insert("init",
|
||||
"uf",
|
||||
"init method:\n ui or 0 - uniform random int\n uf or 1 - uniform random float"
|
||||
"\n tf or 2 - trig float")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("p_drop", "0", "0~1 probability of dropout")
|
||||
.insert("drop_seed", "1", "seed for dropout random number generator")
|
||||
.insert("drop_offset", "0", "offset for dropout random number generator")
|
||||
.insert(
|
||||
"drop_prefs",
|
||||
"0",
|
||||
"whether dropout seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("deterministic",
|
||||
"0",
|
||||
"if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
|
||||
"will not be used")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "fmha_bwd.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataTypeConfig>
|
||||
auto run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
mode_enum mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
auto seqlen_qs = arg_parser.get_int_vec("s");
|
||||
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
|
||||
auto seqlen_ks = arg_parser.get_int_vec("s_k");
|
||||
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
float scale = arg_parser.get_float("scale");
|
||||
std::string bias_str = arg_parser.get_str("bias");
|
||||
bool use_dbias = arg_parser.get_bool("dbias");
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
bool drop_prefs = arg_parser.get_bool("drop_prefs");
|
||||
std::string mask_str = arg_parser.get_str("mask");
|
||||
bool deterministic = arg_parser.get_bool("deterministic");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
|
||||
arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_str("timer") == std::string("gpu")};
|
||||
|
||||
auto json = arg_parser.get_int("json") == 1
|
||||
? std::optional<std::string>{arg_parser.get_str("jsonfile")}
|
||||
: std::nullopt;
|
||||
|
||||
return fmha_bwd_run<DataTypeConfig>(mode,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs,
|
||||
seqlen_ks,
|
||||
seqlen_qpads,
|
||||
seqlen_kpads,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
scale,
|
||||
bias_str,
|
||||
use_dbias,
|
||||
p_drop,
|
||||
drop_seed,
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
deterministic,
|
||||
init_method,
|
||||
seed,
|
||||
do_validation,
|
||||
stream_config,
|
||||
json);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp32")
|
||||
{
|
||||
return run<FmhaBwdFp32>(arg_parser) == bwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp16")
|
||||
{
|
||||
return run<FmhaBwdFp16>(arg_parser) == bwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<FmhaBwdBf16>(arg_parser) == bwd_result::success ? 0 : -2;
|
||||
}
|
||||
std::cerr << "Unsupported precision: " << data_type << std::endl;
|
||||
return -1;
|
||||
}
|
||||
catch(const std::invalid_argument& e)
|
||||
{
|
||||
std::cerr << "Invalid argument: " << e.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << std::endl;
|
||||
return -2;
|
||||
}
|
||||
}
|
||||
271
example/ck_tile/01_fmha/example_fmha_fwd.cpp
Normal file
271
example/ck_tile/01_fmha/example_fmha_fwd.cpp
Normal file
@@ -0,0 +1,271 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "fmha_fwd_runner.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "0:no validation, 2:cpu validation, 2:gpu validation(experimental)")
|
||||
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
|
||||
.insert("b", "2", "batch size")
|
||||
.insert("h", "8", "num of head, for q")
|
||||
.insert("h_k",
|
||||
"-1",
|
||||
"num of head, for k/v, -1 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s",
|
||||
"3328",
|
||||
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
|
||||
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_k",
|
||||
"-1",
|
||||
"seqlen_k (including new key/value), -1 means equal to s\n"
|
||||
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_knew",
|
||||
"0",
|
||||
"seqlen_k for new key/value, 0 means not to use this at all; "
|
||||
"-1 to choose s_knew in [1, s] randomly.")
|
||||
.insert("s_qpad",
|
||||
"-1",
|
||||
"seqlen_q stride between 2 batches (group-mode optional).\n"
|
||||
"Provide positive strides per-batch to simulate physical padding on Q.")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"seqlen_k stride between 2 batches, currently used in group-mode only\n"
|
||||
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
|
||||
"along seqlen, instead of packed, same as xformer kv_padding,\n"
|
||||
"must be greater than or equal to s_k")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("qscale",
|
||||
"n",
|
||||
"quant scale:\n"
|
||||
" n or 0, no scale\n"
|
||||
" pt or 1, per-tensor scale\n"
|
||||
" bs or 2, block scale\n"
|
||||
" kvbs or 3, Q per-tensor, K/V per-page block scale\n"
|
||||
" mx or 4, microscaling (exclusively for data types like mxfp8 and mxfp4)")
|
||||
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("bias",
|
||||
"n",
|
||||
"n or 0, no bias\n"
|
||||
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
|
||||
"a(libi) or 2, alibi with 1*h. a:1, b*h")
|
||||
.insert("prec", "fp16", "data type: fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4")
|
||||
.insert("mask",
|
||||
"0",
|
||||
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
|
||||
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
|
||||
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
|
||||
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
|
||||
"'xt:window_size', xformer style masking from top-left, window_size negative is "
|
||||
"causal, positive is swa\n"
|
||||
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
|
||||
"causal, positive is swa\n"
|
||||
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
|
||||
"now)")
|
||||
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
|
||||
.insert("lse", "0", "0 not store lse, 1 store lse")
|
||||
.insert("kname", "0", "if set to 1 will print kernel name")
|
||||
.insert("init",
|
||||
"uf",
|
||||
"init method:\n ui or 0 - uniform random int\n ni - normalized random int"
|
||||
"\n uf or 1 - uniform random float\n nf - normalized random float"
|
||||
"\n tf or 2 - trig float"
|
||||
"\n tf or 3 - uniform random float, min max is the max of the type\n")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("p_drop", "0", "0~1 probability of dropout")
|
||||
.insert("drop_seed", "1", "seed for dropout random number generator")
|
||||
.insert("drop_offset", "0", "offset for dropout random number generator")
|
||||
.insert(
|
||||
"drop_prefs",
|
||||
"0",
|
||||
"whether dropout seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert(
|
||||
"rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
|
||||
.insert("rotary_interleaved", "1", "whether to apply interleaved RoPE")
|
||||
.insert("num_splits",
|
||||
"1",
|
||||
"# of splits for key/value. 0 to determine actual number by heuristic")
|
||||
.insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe")
|
||||
.insert("cache_batch_idx", "0", "whether to use index map to the kvcache")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "fmha_fwd.json", "json file name to dump results")
|
||||
.insert("q_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("kv_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("init_sink", "0", "value to init the output tensor sink value for validation");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataTypeConfig>
|
||||
auto run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
mode_enum mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
auto seqlen_qs = arg_parser.get_int_vec("s");
|
||||
auto seqlen_ks = arg_parser.get_int_vec("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
|
||||
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
|
||||
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
|
||||
auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens");
|
||||
auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens");
|
||||
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
float scale_s = arg_parser.get_float("scale_s");
|
||||
float logits_soft_cap = arg_parser.get_float("logits_soft_cap");
|
||||
bool is_v_rowmajor = arg_parser.get_str("vlayout") == "r";
|
||||
bool lse = arg_parser.get_bool("lse");
|
||||
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
|
||||
bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
|
||||
std::string bias_str = arg_parser.get_str("bias");
|
||||
std::string qscale_str = arg_parser.get_str("qscale");
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
bool drop_prefs = arg_parser.get_bool("drop_prefs");
|
||||
std::string mask_str = arg_parser.get_str("mask");
|
||||
bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
|
||||
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int init_sink_value = arg_parser.get_int("init_sink");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
|
||||
arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_str("timer") == std::string("gpu")};
|
||||
|
||||
auto json = arg_parser.get_int("json") == 1
|
||||
? std::optional<std::string>{arg_parser.get_str("jsonfile")}
|
||||
: std::nullopt;
|
||||
|
||||
return fmha_fwd_run<DataTypeConfig>(mode,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs,
|
||||
seqlen_ks,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
seqlen_knew,
|
||||
seqlen_qpads,
|
||||
seqlen_kpads,
|
||||
q_eff_lens_per_batch,
|
||||
kv_eff_lens_per_batch,
|
||||
rotary_dim,
|
||||
i_perm,
|
||||
o_perm,
|
||||
scale_s,
|
||||
logits_soft_cap,
|
||||
is_v_rowmajor,
|
||||
lse,
|
||||
page_block_size,
|
||||
use_cache_batch_idx,
|
||||
bias_str,
|
||||
p_drop,
|
||||
drop_seed,
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
qscale_str,
|
||||
is_rotary_interleaved,
|
||||
num_splits,
|
||||
init_method,
|
||||
seed,
|
||||
do_validation,
|
||||
init_sink_value,
|
||||
stream_config,
|
||||
json);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp32")
|
||||
{
|
||||
return run<FmhaFwdFp32>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp16")
|
||||
{
|
||||
return run<FmhaFwdFp16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<FmhaFwdBf16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run<FmhaFwdFp8>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8bf16")
|
||||
{
|
||||
return run<FmhaFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8fp32")
|
||||
{
|
||||
return run<FmhaFwdFp8Fp32>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "mxfp8")
|
||||
{
|
||||
return run<FmhaFwdMxFp8>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "mxfp4")
|
||||
{
|
||||
return run<FmhaFwdMxFp4>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
std::cerr << "Unsupported precision: " << data_type << std::endl;
|
||||
return -1;
|
||||
}
|
||||
catch(const std::invalid_argument& e)
|
||||
{
|
||||
std::cerr << "Invalid argument: " << e.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << std::endl;
|
||||
return -2;
|
||||
}
|
||||
}
|
||||
588
example/ck_tile/01_fmha/fmha_bwd.hpp
Normal file
588
example/ck_tile/01_fmha/fmha_bwd.hpp
Normal file
@@ -0,0 +1,588 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "bias.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
|
||||
struct FmhaBwdFp32
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaBwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaBwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaBwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdFp32>
|
||||
{
|
||||
using QDataType = float;
|
||||
using KDataType = float;
|
||||
using VDataType = float;
|
||||
using GemmDataType = float;
|
||||
using BiasDataType = float;
|
||||
using LSEDataType = float;
|
||||
using AccDataType = float; // data type for gemm accumulation
|
||||
using DDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using ODataType = float;
|
||||
using OGradDataType = float;
|
||||
using QGradDataType = float;
|
||||
using KGradDataType = float;
|
||||
using VGradDataType = float;
|
||||
using BiasGradDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using GemmDataType = ck_tile::half_t;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using LSEDataType = float;
|
||||
using AccDataType = float; // data type for gemm accumulation
|
||||
using DDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using ODataType = ck_tile::half_t;
|
||||
using OGradDataType = ck_tile::half_t;
|
||||
using QGradDataType = ck_tile::half_t;
|
||||
using KGradDataType = ck_tile::half_t;
|
||||
using VGradDataType = ck_tile::half_t;
|
||||
using BiasGradDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using GemmDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using LSEDataType = float;
|
||||
using AccDataType = float; // data type for gemm accumulation
|
||||
using DDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using OGradDataType = ck_tile::bf16_t;
|
||||
using QGradDataType = ck_tile::bf16_t;
|
||||
using KGradDataType = ck_tile::bf16_t;
|
||||
using VGradDataType = ck_tile::bf16_t;
|
||||
using BiasGradDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
// runtime args, some will passed to karg, some will used to compute grids/blocks
|
||||
struct fmha_bwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
const void* o_ptr;
|
||||
const void* lse_ptr;
|
||||
const void* do_ptr;
|
||||
void* d_ptr;
|
||||
void* rand_val_ptr;
|
||||
void* dq_ptr;
|
||||
void* dk_ptr;
|
||||
void* dv_ptr;
|
||||
void* dbias_ptr;
|
||||
void* dq_acc_ptr;
|
||||
|
||||
// Usage notes for sequence length pointer parameters:
|
||||
//
|
||||
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
|
||||
// MQA/GQA..."]
|
||||
//
|
||||
// With padding:
|
||||
// Group mode:
|
||||
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
|
||||
// lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
|
||||
// sequence. [array size: batch]
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
|
||||
// exclusive. Use one set, not both.
|
||||
//
|
||||
// Batch mode:
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// - seqstart_* and seqlen_* pointers must be nullptr.
|
||||
//
|
||||
// Without padding:
|
||||
// (Note: Physical length equals logical length)
|
||||
//
|
||||
// Group mode:
|
||||
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
|
||||
// size: batch + 1]
|
||||
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
|
||||
//
|
||||
// Batch mode:
|
||||
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
|
||||
//
|
||||
const void* seqstart_q_ptr =
|
||||
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
|
||||
const void* seqstart_k_ptr =
|
||||
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
|
||||
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
|
||||
// [batch]. (Used in Group mode with padding)
|
||||
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
|
||||
// [batch]. (Used in Group mode with padding)
|
||||
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t max_seqlen_k;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
float scale;
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_do;
|
||||
ck_tile::index_t stride_dq_acc;
|
||||
ck_tile::index_t stride_dq;
|
||||
ck_tile::index_t stride_dk;
|
||||
ck_tile::index_t stride_dv;
|
||||
ck_tile::index_t stride_dbias;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::long_index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dq;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
ck_tile::index_t nhead_stride_dbias;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
ck_tile::long_index_t batch_stride_dq_acc;
|
||||
ck_tile::index_t batch_stride_dq;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
ck_tile::index_t batch_stride_dbias;
|
||||
ck_tile::index_t split_stride_dq_acc;
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
float p_drop;
|
||||
float p_undrop;
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
|
||||
|
||||
template <typename FmhaBwdDQDKDVKernel>
|
||||
auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0;
|
||||
const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr;
|
||||
const auto stride_dq = dq_uss_acc ? args.stride_dq_acc : args.stride_dq;
|
||||
const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq;
|
||||
const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq;
|
||||
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.batch,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
stride_dq,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
nhead_stride_dq,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
dq_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.batch,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
stride_dq,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
nhead_stride_dq,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
batch_stride_dq,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaBwdOGradDotOKernel>
|
||||
auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.p_undrop,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_lsed);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.p_undrop,
|
||||
args.seqlen_q,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_lsed,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_o,
|
||||
args.batch_stride_lsed);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaBwdConvertQGradKernel>
|
||||
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.batch_stride_dq,
|
||||
args.batch_stride_dq_acc,
|
||||
args.split_stride_dq_acc,
|
||||
args.batch,
|
||||
args.nhead_q);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
typename FmhaMask_,
|
||||
typename FmhaDropout_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
ck_tile::index_t kPadD_,
|
||||
ck_tile::index_t kPadDv_,
|
||||
bool kIsDeterministic_,
|
||||
bool kUseTrLoad_,
|
||||
ck_tile::index_t MaxSeqLenQ_,
|
||||
ck_tile::index_t kN0>
|
||||
struct fmha_bwd_dq_dk_dv_traits_
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_();
|
||||
template <typename Traits_, typename Arch = void>
|
||||
int fmha_bwd_dq_dk_dv_maxq_();
|
||||
struct fmha_bwd_traits;
|
||||
template <typename Traits_, typename Arch = void>
|
||||
int fmha_bwd_dq_dk_dv_dq_acc_splits_(const fmha_bwd_traits& t);
|
||||
template <typename Traits_, typename Arch = void>
|
||||
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_();
|
||||
|
||||
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
|
||||
struct fmha_bwd_dot_do_o_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
std::string fmha_bwd_dot_do_o_get_name_();
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
bool kPadS_,
|
||||
bool kPadD_,
|
||||
bool kIsDeterministic_,
|
||||
ck_tile::index_t kN0>
|
||||
struct fmha_bwd_convert_dq_traits_
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
std::string fmha_bwd_convert_dq_get_name_();
|
||||
|
||||
// Traits that are used to dispatch different kernel implementations for fmha backward
|
||||
struct fmha_bwd_traits
|
||||
{
|
||||
int seqlen_q;
|
||||
int seqlen_k;
|
||||
int batch;
|
||||
int max_seqlen_q;
|
||||
int max_seqlen_k;
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
int nhead_q;
|
||||
int nhead_k;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
mask_enum mask_type;
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_dbias;
|
||||
bool has_dropout;
|
||||
bool is_store_randval;
|
||||
bool is_deterministic;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
template <typename T0 /*dot_do_o_trait*/,
|
||||
typename T1 /*dq_dk_dv_trait*/,
|
||||
typename T2 /*convert_dq_trait*/,
|
||||
typename Arch>
|
||||
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
if constexpr(!std::is_same_v<T2, void>)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<T0, Arch>() << "@"
|
||||
<< fmha_bwd_convert_dq_get_name_<T2, Arch>() << "@"
|
||||
<< fmha_bwd_dq_dk_dv_get_name_<T1, Arch>() << std::flush;
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
[=](const ck_tile::stream_config& s_) { fmha_bwd_dot_do_o_oneshot_<T0, Arch>(s_, a); },
|
||||
[=](const ck_tile::stream_config& s_) { fmha_bwd_dq_dk_dv_oneshot_<T1, Arch>(s_, a); },
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
fmha_bwd_convert_dq_oneshot_<T2, Arch>(s_, a);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<T0, Arch>() << "@"
|
||||
<< fmha_bwd_dq_dk_dv_get_name_<T1, Arch>() << std::flush;
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
[=](const ck_tile::stream_config& s_) { fmha_bwd_dot_do_o_oneshot_<T0, Arch>(s_, a); },
|
||||
[=](const ck_tile::stream_config& s_) { fmha_bwd_dq_dk_dv_oneshot_<T1, Arch>(s_, a); });
|
||||
}
|
||||
}
|
||||
|
||||
template <int Version = 2>
|
||||
float fmha_bwd(const fmha_bwd_traits&, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
|
||||
struct fmha_bwd_launcher
|
||||
{
|
||||
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};
|
||||
ck_tile::index_t dq_acc_splits{0};
|
||||
bool needs_zero_dq_acc{true};
|
||||
|
||||
fmha_bwd_launcher(const fmha_bwd_traits&);
|
||||
|
||||
template <typename... Args>
|
||||
float operator()(Args&&... args) const
|
||||
{
|
||||
return run(std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
1091
example/ck_tile/01_fmha/fmha_bwd_runner.hpp
Normal file
1091
example/ck_tile/01_fmha/fmha_bwd_runner.hpp
Normal file
File diff suppressed because it is too large
Load Diff
1736
example/ck_tile/01_fmha/fmha_fwd.hpp
Normal file
1736
example/ck_tile/01_fmha/fmha_fwd.hpp
Normal file
File diff suppressed because it is too large
Load Diff
2306
example/ck_tile/01_fmha/fmha_fwd_runner.hpp
Normal file
2306
example/ck_tile/01_fmha/fmha_fwd_runner.hpp
Normal file
File diff suppressed because it is too large
Load Diff
179
example/ck_tile/01_fmha/generate.py
Normal file
179
example/ck_tile/01_fmha/generate.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import pkgutil
|
||||
from typing import List, Optional
|
||||
|
||||
import codegen.ops
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
|
||||
|
||||
class HandlerId(IntEnum):
|
||||
LIST_BLOBS = 0
|
||||
WRITE_BLOBS = 1
|
||||
|
||||
|
||||
# inspect all modules under 'codegen.ops' and register API handlers
|
||||
ops = []
|
||||
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
|
||||
full_module_name = "%s.%s" % (codegen.ops.__name__, module_name)
|
||||
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
|
||||
unwanted_prefix = "fmha_"
|
||||
handlers = dict(
|
||||
[
|
||||
(
|
||||
op.__name__[len(unwanted_prefix) :]
|
||||
if op.__name__.startswith(unwanted_prefix)
|
||||
else op.__name__,
|
||||
(op.list_blobs, op.write_blobs),
|
||||
)
|
||||
for op in ops
|
||||
]
|
||||
)
|
||||
assert 0 < len(handlers)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
targets: List[str],
|
||||
output_dir: Optional[str],
|
||||
api_list: List[str],
|
||||
filters_list: List[str],
|
||||
optdim_list: List[int],
|
||||
receipt,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
if output_dir is None:
|
||||
output_dir = Path(__file__).parent
|
||||
else:
|
||||
output_dir = Path(output_dir) / GEN_DIR
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.WRITE_BLOBS]
|
||||
handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
|
||||
# list all the files that will be generated
|
||||
def list_blobs(
|
||||
targets: List[str],
|
||||
output_file: Optional[str],
|
||||
api_list: List[str],
|
||||
filters_list: List[str],
|
||||
optdim_list: List[int],
|
||||
receipt,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
assert output_file is not None
|
||||
file_path = Path(output_file)
|
||||
|
||||
# create an empty file / drop its contents if it exists
|
||||
open(file_path, "w").close()
|
||||
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.LIST_BLOBS]
|
||||
handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
description="gen API for CK fmha kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--targets",
|
||||
default="gfx9,gfx950",
|
||||
required=False,
|
||||
help="list of GPU targets, separated by comma.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--direction", # we keep 'direction' option for backward compatibility
|
||||
"-a",
|
||||
"--api",
|
||||
default="fwd",
|
||||
required=False,
|
||||
help="supply API(s) to generate (default: fwd). separated by comma.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output_dir",
|
||||
required=False,
|
||||
help="write all the blobs into a directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--list_blobs", required=False, help="list all the kernels to a file"
|
||||
)
|
||||
# TODO: if using filter, must apply same value to output_dir and list_blobs
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--filter",
|
||||
default="",
|
||||
required=False,
|
||||
help="filter out kernels that need to generate, using fnmatch module",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--mask",
|
||||
default="simplified",
|
||||
required=False,
|
||||
help="mask implementation, simplified/generic",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--receipt",
|
||||
default=0,
|
||||
required=False,
|
||||
help="codegen receipt. 0: generate only 8xhdim coverage\n"
|
||||
+ " 1: generate more instance to cover all hdim\n"
|
||||
+ " 2: Only generate instance for Flash attention integration\n"
|
||||
+ " 4: Only generate instance for PyTorch integration\n"
|
||||
+ " 100-199: Only generate instance for Aiter(mha_fwd) integration\n"
|
||||
+ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
|
||||
+ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n"
|
||||
+ " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
|
||||
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--optdim",
|
||||
default="-1",
|
||||
required=False,
|
||||
help="only optimize the hdim in the list. separated by comma. -1 is the default choice"
|
||||
+ "eg. --optdim=32,64,128,256",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
targets = args.targets.split(",")
|
||||
api_list = args.direction.split(",")
|
||||
filter_list = args.filter.split(",")
|
||||
filter_list.extend([""] * (len(api_list) - len(filter_list)))
|
||||
optdim_list = [int(hdim) for hdim in args.optdim.split(",")]
|
||||
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(
|
||||
targets,
|
||||
args.list_blobs,
|
||||
api_list,
|
||||
filter_list,
|
||||
optdim_list,
|
||||
int(args.receipt),
|
||||
mask_impl=args.mask,
|
||||
)
|
||||
else:
|
||||
write_blobs(
|
||||
targets,
|
||||
args.output_dir,
|
||||
api_list,
|
||||
filter_list,
|
||||
optdim_list,
|
||||
int(args.receipt),
|
||||
mask_impl=args.mask,
|
||||
)
|
||||
203
example/ck_tile/01_fmha/mask.hpp
Normal file
203
example/ck_tile/01_fmha/mask.hpp
Normal file
@@ -0,0 +1,203 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
|
||||
// keep this in sync with ck_tile::GenericAttentionMaskEnum
|
||||
enum class mask_enum
|
||||
{
|
||||
no_mask = 0,
|
||||
mask_top_left,
|
||||
mask_bottom_right,
|
||||
window_generic,
|
||||
};
|
||||
|
||||
struct mask_info
|
||||
{
|
||||
mask_enum type;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t y, x;
|
||||
ck_tile::index_t left, right; // FA style SWA left/right
|
||||
ck_tile::index_t sink;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
os << "n";
|
||||
else if(type == mask_enum::mask_top_left)
|
||||
os << "t(" << left << ":" << right << ")";
|
||||
else if(type == mask_enum::mask_bottom_right)
|
||||
os << "b(" << left << ":" << right << ")";
|
||||
else
|
||||
{
|
||||
os << "g(" << y << ":" << x << ")";
|
||||
}
|
||||
}
|
||||
|
||||
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
|
||||
{
|
||||
ck_tile::index_t x_total = seqlen_k;
|
||||
ck_tile::index_t y_total = seqlen_q;
|
||||
mask_info tmp;
|
||||
tmp.seqlen_q = seqlen_q;
|
||||
tmp.seqlen_k = seqlen_k;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string t = str.substr(0, found_0);
|
||||
std::string v = str.substr(found_0 + 1);
|
||||
if(t == "xt" || t == "xb")
|
||||
{
|
||||
// xformer style sliding window attn from top-left
|
||||
ck_tile::index_t window_size = std::stoi(v);
|
||||
ck_tile::index_t left_size = -1;
|
||||
ck_tile::index_t right_size = 0;
|
||||
ck_tile::index_t sink_size = 0;
|
||||
if(window_size > 0)
|
||||
{
|
||||
left_size = window_size / 2;
|
||||
right_size = window_size - 1 - left_size;
|
||||
}
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, sink_size, y_total, x_total, t == "xt");
|
||||
|
||||
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = left_size;
|
||||
tmp.right = right_size;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
else if(t == "t" || t == "b" || t == "g")
|
||||
{
|
||||
auto found_1 = v.find(",");
|
||||
if(found_1 == std::string::npos)
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
tmp.type = mask_enum::window_generic;
|
||||
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
|
||||
auto found_2 = v.find(',', found_1 + 1);
|
||||
ck_tile::index_t v1 = 0;
|
||||
ck_tile::index_t sink = 0;
|
||||
// ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
// TODO: some validation
|
||||
if(t == "t")
|
||||
{
|
||||
if(found_2 != std::string::npos)
|
||||
{
|
||||
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
|
||||
sink = atoi(v.substr(found_2 + 1).c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
sink = 0;
|
||||
}
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, sink, y_total, x_total, true);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
tmp.sink = sink;
|
||||
}
|
||||
else if(t == "b")
|
||||
{
|
||||
if(found_2 != std::string::npos)
|
||||
{
|
||||
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
|
||||
sink = atoi(v.substr(found_2 + 1).c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
sink = 0;
|
||||
}
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, sink, y_total, x_total, false);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
tmp.sink = sink;
|
||||
}
|
||||
else if(t == "g")
|
||||
{
|
||||
tmp.type = mask_enum::window_generic;
|
||||
tmp.y = v0;
|
||||
tmp.x = v1;
|
||||
tmp.left = v0; // TODO: don't use this?
|
||||
tmp.right = v1;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
}
|
||||
else if(str == "0")
|
||||
{
|
||||
tmp.type = mask_enum::no_mask;
|
||||
tmp.left = -1;
|
||||
tmp.right = -1;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
else if(str == "1" || str == "t")
|
||||
{
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
tmp.y = seqlen_q;
|
||||
tmp.x = 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
else if(str == "2" || str == "b")
|
||||
{
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
tmp.y = seqlen_q;
|
||||
tmp.x = seqlen_k - seqlen_q + 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
return tmp;
|
||||
}
|
||||
|
||||
std::size_t get_unmaskarea() const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
return static_cast<std::size_t>(seqlen_q) * seqlen_k;
|
||||
std::size_t area = 0;
|
||||
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
|
||||
{
|
||||
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
|
||||
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
|
||||
if(x_end > x_start)
|
||||
{
|
||||
area += (x_end - x_start);
|
||||
}
|
||||
}
|
||||
return area;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const mask_info& mi)
|
||||
{
|
||||
mi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
BIN
example/ck_tile/01_fmha/misc/gamc.png
Normal file
BIN
example/ck_tile/01_fmha/misc/gamc.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 29 KiB |
78
example/ck_tile/01_fmha/quant.hpp
Normal file
78
example/ck_tile/01_fmha/quant.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
// keep sync with BlockAttentionQuantScaleEnum
|
||||
enum class quant_scale_enum
|
||||
{
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale = 2,
|
||||
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
|
||||
mx = 4, // Microscaling (MX)
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
{
|
||||
quant_scale_enum type;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == quant_scale_enum::no_scale)
|
||||
os << "n";
|
||||
else if(type == quant_scale_enum::pertensor)
|
||||
os << "pt";
|
||||
else if(type == quant_scale_enum::blockscale)
|
||||
os << "bs";
|
||||
else if(type == quant_scale_enum::kv_blockscale)
|
||||
os << "kvbs";
|
||||
else if(type == quant_scale_enum::mx)
|
||||
os << "mx";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
{
|
||||
quant_scale_info info{quant_scale_enum::no_scale};
|
||||
if(str == "n" || str == "0")
|
||||
{
|
||||
info.type = quant_scale_enum::no_scale;
|
||||
}
|
||||
else if(str == "pt" || str == "1")
|
||||
{
|
||||
info.type = quant_scale_enum::pertensor;
|
||||
}
|
||||
else if(str == "bs" || str == "2")
|
||||
{
|
||||
info.type = quant_scale_enum::blockscale;
|
||||
}
|
||||
else if(str == "kvbs" || str == "3")
|
||||
{
|
||||
info.type = quant_scale_enum::kv_blockscale;
|
||||
}
|
||||
else if(str == "mx" || str == "4")
|
||||
{
|
||||
info.type = quant_scale_enum::mx;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const quant_scale_info& qsi)
|
||||
{
|
||||
qsi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
#pragma clang diagnostic pop
|
||||
89
example/ck_tile/01_fmha/rotary.hpp
Normal file
89
example/ck_tile/01_fmha/rotary.hpp
Normal file
@@ -0,0 +1,89 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#ifndef M_PI // Not there on windows...
|
||||
#define M_PI 3.141592653589793238462643383279502884
|
||||
#endif
|
||||
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <optional>
|
||||
#include <random>
|
||||
#include <tuple>
|
||||
|
||||
// keep sync with RotaryEmbeddingEnum
|
||||
enum class rope_enum
|
||||
{
|
||||
none = 0,
|
||||
interleaved = 1,
|
||||
half_rotated = 2,
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
|
||||
generate_rotary_cos_sin(ck_tile::index_t seqlen,
|
||||
ck_tile::index_t rotary_dim,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
// return dummy tensors if we won't apply RoPE at all
|
||||
if(rotary_dim <= 0)
|
||||
{
|
||||
ck_tile::HostTensor<DataType> dummy({1, 1});
|
||||
return std::make_tuple(dummy, dummy);
|
||||
}
|
||||
|
||||
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
|
||||
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
|
||||
|
||||
const ck_tile::index_t num_rows = seqlen * 2;
|
||||
const ck_tile::index_t num_cols = rotary_dim / 2;
|
||||
|
||||
using std::begin, std::end;
|
||||
|
||||
ck_tile::HostTensor<float> angle({num_rows, num_cols});
|
||||
std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });
|
||||
|
||||
ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
|
||||
std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
|
||||
return ck_tile::type_convert<DataType>(std::cos(origin_value));
|
||||
});
|
||||
|
||||
ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
|
||||
std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
|
||||
return ck_tile::type_convert<DataType>(std::sin(origin_value));
|
||||
});
|
||||
|
||||
return std::make_tuple(cos, sin);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
|
||||
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
|
||||
const ck_tile::HostTensor<DataType>& sin,
|
||||
ck_tile::index_t seqlen_offset,
|
||||
ck_tile::index_t seqlen)
|
||||
{
|
||||
assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
|
||||
assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
|
||||
|
||||
assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));
|
||||
|
||||
const ck_tile::index_t num_rows = seqlen;
|
||||
const ck_tile::index_t num_cols = cos.get_length(1);
|
||||
|
||||
ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
|
||||
cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });
|
||||
|
||||
ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
|
||||
sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });
|
||||
|
||||
return std::make_tuple(cos_pt, sin_pt);
|
||||
}
|
||||
23
example/ck_tile/01_fmha/script/benchmark_bwd.sh
Executable file
23
example/ck_tile/01_fmha/script/benchmark_bwd.sh
Executable file
@@ -0,0 +1,23 @@
|
||||
#!/bin/sh
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# TODO: run this script from CK root or build directory
|
||||
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
|
||||
VALID=0
|
||||
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 32 64 128 ; do
|
||||
|
||||
nhead=$((2048 / $hdim)) # follow fav2 setup
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
56
example/ck_tile/01_fmha/script/benchmark_fwd.sh
Executable file
56
example/ck_tile/01_fmha/script/benchmark_fwd.sh
Executable file
@@ -0,0 +1,56 @@
|
||||
#!/bin/sh
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# TODO: run this script from CK root or build directory
|
||||
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
|
||||
VALID=0
|
||||
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
|
||||
nhead=$((2048 / $hdim)) # follow fav2 setup
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
#Padding Benchmarks: batch mode (baseline vs low/med/high pad)
|
||||
prec="fp16"
|
||||
base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"
|
||||
|
||||
# baseline (no pad)
|
||||
$EXE $base_batch_args
|
||||
|
||||
# low pad (≈90–95% effective)
|
||||
$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
|
||||
|
||||
# medium pad (≈60–75% effective)
|
||||
$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
|
||||
|
||||
# high pad (≈30–40% effective)
|
||||
$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
|
||||
|
||||
# Padding Benchmarks: group mode (baseline vs low/med/high physical pad)
|
||||
seqlens_q="1024,768,512,256"
|
||||
seqlens_k="1024,768,512,256"
|
||||
base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"
|
||||
|
||||
# baseline (no physical pad)
|
||||
$EXE $base_group_args
|
||||
|
||||
# low physical pad
|
||||
$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
|
||||
|
||||
# medium physical pad
|
||||
$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384
|
||||
|
||||
# high physical pad
|
||||
$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512
|
||||
46
example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh
Executable file
46
example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh
Executable file
@@ -0,0 +1,46 @@
|
||||
#!/bin/sh
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
# TODO: run this script from CK root or build directory
|
||||
EXE="$(find . -name tile_example_fmha_fwd_v3 -type f | head -n 1)"
|
||||
VALID=0
|
||||
|
||||
for causal in 0 1 ; do
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for hdim in 128 ; do
|
||||
for perm in 0 ; do
|
||||
|
||||
$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
|
||||
$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# Padding benchmark comparisons for v3 (batch mode only)
|
||||
# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ====
|
||||
prec="fp16"
|
||||
base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID"
|
||||
|
||||
# baseline (no pad)
|
||||
$EXE $base_v3_args
|
||||
|
||||
# low pad (≈90–95% effective)
|
||||
$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
|
||||
|
||||
# medium pad (≈60–75% effective)
|
||||
$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
|
||||
|
||||
# high pad (≈30–40% effective)
|
||||
$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
|
||||
77
example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh
Normal file
77
example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh
Normal file
@@ -0,0 +1,77 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# TODO: run this script from CK root or build directory
|
||||
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
|
||||
KNAME=1
|
||||
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
# mode=0
|
||||
# export HIP_VISIBLE_DEVICES=4
|
||||
|
||||
TEST_SPLITKV=0
|
||||
TEST_APPENDKV=0
|
||||
# options:
|
||||
# -s: run splitkv tests
|
||||
# -a: run appendkv tests
|
||||
while getopts ":sa" opt; do
|
||||
case "${opt}" in
|
||||
s)
|
||||
TEST_SPLITKV=1
|
||||
;;
|
||||
a)
|
||||
TEST_APPENDKV=1
|
||||
;;
|
||||
*)
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
run_fp16_bf16_tests() {
|
||||
local NUM_SPLITS="1"
|
||||
local PAGE_BLOCK_SIZE="0"
|
||||
local CACHE_BATCH_IDX="0"
|
||||
|
||||
if [ $TEST_SPLITKV -eq 1 ] ; then
|
||||
NUM_SPLITS="$NUM_SPLITS 2 3"
|
||||
PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
|
||||
CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
|
||||
fi
|
||||
|
||||
for prec in "fp16"; do
|
||||
for mode in 1 0 ; do
|
||||
for perm in 0 1 ; do
|
||||
for vlayout in "r" "c" ; do
|
||||
for batch in 1 4; do
|
||||
for head in 1; do
|
||||
for h_k in 1; do
|
||||
for q_seq in 128 512 ; do
|
||||
for kv_seq in 128 1024; do
|
||||
for hdim in 32 64 128 256; do #256
|
||||
for lse in 0 1 ; do
|
||||
for bias in "e" ; do
|
||||
for p_drop in 0.0 0.2; do # 0.0
|
||||
for mask in "t:2,0,4" "b:1,0,2"; do
|
||||
for num_splits in $NUM_SPLITS ; do
|
||||
for page_block_size in $PAGE_BLOCK_SIZE ; do
|
||||
for cache_batch_idx in $CACHE_BATCH_IDX ; do
|
||||
|
||||
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=$batch -h=$head -h_k=$h_k -d=16 -d_v=$hdim -s=$q_seq -s_k=$kv_seq -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS -mask=$mask
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done ; done ; done
|
||||
done ; done
|
||||
}
|
||||
|
||||
|
||||
set -x
|
||||
|
||||
run_fp16_bf16_tests
|
||||
|
||||
set +x
|
||||
52
example/ck_tile/01_fmha/script/run_full_test.sh
Executable file
52
example/ck_tile/01_fmha/script/run_full_test.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
#
|
||||
# in order to run this script you'd first need to build the tile_example_fmha_fwd and tile_eaxmple_fmha_bwd executables in ../build/bin/
|
||||
#
|
||||
# run the script as "./run_full_test.sh <tag for your test environment> <branch name> <host name> <gpu_arch>
|
||||
# input arguments:
|
||||
# environment tag : a string describing the specifics of your test environment
|
||||
# branch name : name of the branch in git repo (git status | grep -e 'On branch')
|
||||
# host name : $hostname
|
||||
# gpu architecture: e.g., gfx90a, or gfx942, etc.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
#get the command line arguments:
|
||||
export env_type=$1
|
||||
echo 'Environment type: ' $env_type
|
||||
export branch=$2
|
||||
echo 'Branch name: ' $branch
|
||||
export host_name=$3
|
||||
echo 'Host name: ' $host_name
|
||||
export GPU_arch=$4
|
||||
echo 'GPU_arch: ' $GPU_arch
|
||||
|
||||
function print_log_header(){
|
||||
rm -f $1;
|
||||
echo 'On branch ' $3 &> $1;
|
||||
echo 'Node name: ' $4 >> $1;
|
||||
#get GPU_arch and number of compute units from rocminfo
|
||||
echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
|
||||
rocminfo | grep "Compute Unit:" >> $1;
|
||||
hipcc --version | grep -e 'HIP version' >> $1;
|
||||
echo 'Environment type: ' $2 >> $1;
|
||||
/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
|
||||
}
|
||||
|
||||
#run verification tests
|
||||
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
|
||||
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
|
||||
time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
|
||||
|
||||
#run performance benchmarks
|
||||
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
|
||||
print_log_header $fmha_fwd_log $env_type $branch $host_name
|
||||
time example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log
|
||||
|
||||
export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log"
|
||||
print_log_header $fmha_bwd_log $env_type $branch $host_name
|
||||
time example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log
|
||||
|
||||
93
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
Executable file
93
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
Executable file
@@ -0,0 +1,93 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# TODO: run this script from CK root or build directory
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
|
||||
EXE_NAME=tile_example_fmha_bwd
|
||||
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
|
||||
KNAME=1
|
||||
GPU_arch=${GPU_arch:-""}
|
||||
if [ -z "$GPU_arch" ] ; then
|
||||
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
|
||||
fi
|
||||
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_bwd_fails_$GPU_arch.txt"}
|
||||
rm -f $CURR_FAILS_FILE
|
||||
touch $CURR_FAILS_FILE
|
||||
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_bwd_known_fails_$GPU_arch.txt"}
|
||||
|
||||
COMMON_ARGS='-v=1'
|
||||
|
||||
run_exe() {
|
||||
set +ex
|
||||
$EXE $@
|
||||
local ret=$?
|
||||
if [ $ret -ne 0 ] ; then
|
||||
echo "$EXE_NAME $*" >> $CURR_FAILS_FILE
|
||||
fi
|
||||
set -ex
|
||||
}
|
||||
|
||||
test_h_s_mask() {
|
||||
run_exe -b=1 -h=4 -h_k=2 -s=259 $@
|
||||
run_exe -b=2 -h=2 -s=516 -s_k=253 $@
|
||||
run_exe -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 $@
|
||||
run_exe -b=1 -h=2 -s=900 -s_k=258 -mask=2 $@
|
||||
run_exe -b=2 -h=1 -s=987 -s_k=219 -mask=t:128,30 $@
|
||||
run_exe -b=2 -h=3 -h_k=1 -s=244 -s_k=499 -mask=b:4,35 $@
|
||||
}
|
||||
|
||||
set -x
|
||||
# main tests
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 32 64 128 256 ; do
|
||||
for mode in 0 1 ; do
|
||||
for bias in "n" "a" ; do
|
||||
for dbias in 0 ; do
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
for deterministic in 0 ; do
|
||||
test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# additional cases
|
||||
for hdim in 40 48 72 96 ; do
|
||||
test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
set +x
|
||||
|
||||
new_fails_count=0
|
||||
known_fails_count=0
|
||||
if [ -f $KNOWN_FAILS_FILE ] ; then
|
||||
echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
|
||||
while IFS= read -r line; do
|
||||
if grep -Fxq "$line" $KNOWN_FAILS_FILE; then
|
||||
echo "Known fail: $line"
|
||||
known_fails_count=$(($known_fails_count + 1))
|
||||
else
|
||||
echo "New fail: $line"
|
||||
new_fails_count=$(($new_fails_count + 1))
|
||||
fi
|
||||
done < $CURR_FAILS_FILE
|
||||
else
|
||||
new_fails_count=$(wc -l < $CURR_FAILS_FILE)
|
||||
echo "No known fails file, all fails ($new_fails_count) are new:"
|
||||
cat $CURR_FAILS_FILE
|
||||
fi
|
||||
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
|
||||
exit $(($new_fails_count != 0))
|
||||
271
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
Executable file
271
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
Executable file
@@ -0,0 +1,271 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# TODO: run this script from CK root or build directory
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
|
||||
EXE_NAME=tile_example_fmha_fwd
|
||||
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
|
||||
KNAME=1
|
||||
GPU_arch=$GPU_arch
|
||||
if [ -z "$GPU_arch" ] ; then
|
||||
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
|
||||
fi
|
||||
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_fwd_fails_$GPU_arch.txt"}
|
||||
rm -f $CURR_FAILS_FILE
|
||||
touch $CURR_FAILS_FILE
|
||||
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_fwd_known_fails_$GPU_arch.txt"}
|
||||
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
# mode=0
|
||||
# export HIP_VISIBLE_DEVICES=4
|
||||
|
||||
TEST_SPLITKV=0
|
||||
TEST_APPENDKV=0
|
||||
# options:
|
||||
# -s: run splitkv tests
|
||||
# -a: run appendkv tests
|
||||
while getopts ":sa" opt; do
|
||||
case "${opt}" in
|
||||
s)
|
||||
TEST_SPLITKV=1
|
||||
;;
|
||||
a)
|
||||
TEST_APPENDKV=1
|
||||
;;
|
||||
*)
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
run_exe() {
|
||||
set +ex
|
||||
$EXE $@
|
||||
local ret=$?
|
||||
if [ $ret -ne 0 ] ; then
|
||||
echo "$EXE_NAME $*" >> $CURR_FAILS_FILE
|
||||
fi
|
||||
set -ex
|
||||
}
|
||||
|
||||
run_fp16_bf16_tests() {
|
||||
local NUM_SPLITS="1"
|
||||
local PAGE_BLOCK_SIZE="0"
|
||||
local CACHE_BATCH_IDX="0"
|
||||
|
||||
if [ $TEST_SPLITKV -eq 1 ] ; then
|
||||
NUM_SPLITS="$NUM_SPLITS 2 3"
|
||||
PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
|
||||
CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
|
||||
fi
|
||||
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for mode in 1 0 ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 32 64 128 256 ; do
|
||||
for lse in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
for num_splits in $NUM_SPLITS ; do
|
||||
for page_block_size in $PAGE_BLOCK_SIZE ; do
|
||||
for cache_batch_idx in $CACHE_BATCH_IDX ; do
|
||||
|
||||
# run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8bf16_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
for scale in 1 2; do
|
||||
|
||||
$EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=$scale -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8fp32_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 128 ; do
|
||||
|
||||
$EXE -prec=fp8fp32 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done
|
||||
}
|
||||
|
||||
run_fp16_appendkv_tests() {
|
||||
for s in $(seq 63 1 65) ; do
|
||||
for s_k in 65 129 ; do
|
||||
for s_knew in 0 64 $s_k ; do
|
||||
for hdim in 32 64 128 256 ; do
|
||||
for ri in 0 1 ; do
|
||||
for rdim in 0 16 32 $hdim ; do
|
||||
for page_block_size in 0 128 ; do
|
||||
for cache_batch_idx in 0 1 ; do
|
||||
|
||||
run_exe -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done
|
||||
}
|
||||
|
||||
run_padding_smoke_tests() {
|
||||
# Padding-only smoke tests for batch/group mode using COMMON_ARGS
|
||||
local prec="fp16"
|
||||
|
||||
# Batch mode: padding via effective lengths (exclude PAD)
|
||||
# Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches
|
||||
local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS"
|
||||
# low pad (≈90–95% effective)
|
||||
$EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
|
||||
# medium pad (≈60–75% effective)
|
||||
$EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
|
||||
# high pad (≈30–40% effective)
|
||||
$EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
|
||||
|
||||
# Group mode: padding via physical stride along seqlen
|
||||
local seqlens_q="1024,768,512,256"
|
||||
local seqlens_k="1024,768,512,256"
|
||||
local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS"
|
||||
# low physical pad
|
||||
$EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
|
||||
# medium physical pad
|
||||
$EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384
|
||||
# high physical pad
|
||||
$EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512
|
||||
}
|
||||
|
||||
run_padding_basic_boundary_tests() {
|
||||
# Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh)
|
||||
local prec
|
||||
local perm
|
||||
|
||||
# Group mode: Q&K padded with per-batch different strides
|
||||
for prec in fp16 bf16 ; do
|
||||
for perm in 0 1 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \
|
||||
-s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \
|
||||
-bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
|
||||
# slightly larger, uneven padding strides
|
||||
for prec in fp16 bf16 ; do
|
||||
for perm in 0 1 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \
|
||||
-s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
|
||||
# only K padded; Q unpadded
|
||||
for prec in fp16 bf16 ; do
|
||||
for perm in 0 1 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \
|
||||
-s=55 -s_k=256 -s_kpad=272,260 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
|
||||
# use cu_seqlen overrides to skip tail PAD
|
||||
for prec in fp16 bf16 ; do
|
||||
for perm in 0 1 ; do
|
||||
$EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \
|
||||
-q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
$EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \
|
||||
-q_eff_lens=55,60 -kv_eff_lens=200,256 \
|
||||
-bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
|
||||
# no padding (equal), mixed Q/KV, all len=1
|
||||
for prec in fp16 bf16 ; do
|
||||
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
|
||||
-q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
|
||||
-q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
|
||||
-q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
|
||||
# highly variable logical lengths
|
||||
for prec in fp16 bf16 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \
|
||||
-s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
|
||||
# GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16
|
||||
for prec in fp16 bf16 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \
|
||||
-s=256,129 -s_k=256,129 -s_kpad=256 \
|
||||
-bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
}
|
||||
|
||||
set -x
|
||||
|
||||
run_fp16_bf16_tests
|
||||
run_padding_smoke_tests
|
||||
run_padding_basic_boundary_tests
|
||||
run_fp8bf16_tests
|
||||
run_fp8fp32_tests
|
||||
|
||||
if [ $TEST_APPENDKV -eq 1 ] ; then
|
||||
run_fp16_appendkv_tests
|
||||
fi
|
||||
|
||||
set +x
|
||||
|
||||
new_fails_count=0
|
||||
known_fails_count=0
|
||||
if [ -f $KNOWN_FAILS_FILE ] ; then
|
||||
echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
|
||||
while IFS= read -r line; do
|
||||
if grep -Fxq "$line" $KNOWN_FAILS_FILE; then
|
||||
echo "Known fail: $line"
|
||||
known_fails_count=$(($known_fails_count + 1))
|
||||
else
|
||||
echo "New fail: $line"
|
||||
new_fails_count=$(($new_fails_count + 1))
|
||||
fi
|
||||
done < $CURR_FAILS_FILE
|
||||
else
|
||||
new_fails_count=$(wc -l < $CURR_FAILS_FILE)
|
||||
echo "No known fails file, all fails ($new_fails_count) are new:"
|
||||
cat $CURR_FAILS_FILE
|
||||
fi
|
||||
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
|
||||
exit $(($new_fails_count != 0))
|
||||
93
example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
Executable file
93
example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
Executable file
@@ -0,0 +1,93 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# TODO: run this script from CK root or build directory
|
||||
#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd"
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
|
||||
EXE_NAME=tile_example_fmha_fwd
|
||||
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
|
||||
KNAME=1
|
||||
GPU_arch=$GPU_arch
|
||||
if [ -z "$GPU_arch" ] ; then
|
||||
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
|
||||
fi
|
||||
set -x
|
||||
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2
|
||||
|
||||
# window_size[2,0], sink_size = 2
|
||||
|
||||
# x=1/y=3
|
||||
# 1 * * * * * * * 1 * * * * * * *
|
||||
# 1 1 * * * * * * 1 1 * * * * * *
|
||||
# 1 1 1 * * * * * ----> 1 1 1 * * * * *
|
||||
# * 1 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * * 1 1 1 * * * 1 1 1 1 1 * * *
|
||||
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
|
||||
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
|
||||
# * * * * * 1 1 1 1 1 * * * 1 1 1
|
||||
# l=2/r=0(tl) l=2/r=0/s=2(tl)
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2
|
||||
|
||||
# x=4/y=1
|
||||
# 1 1 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
|
||||
# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * *
|
||||
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
|
||||
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
|
||||
# l=0/r=3(tl) l=0/r=3/s=2(tl)
|
||||
# l=3/r=0(br) l=3/r=0/s=2(br)
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2
|
||||
|
||||
# x=4/y=-1
|
||||
# * * 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * * * 1 1 * * * 1 1 * 1 1 * * *
|
||||
# * * * * 1 1 * * ----> 1 1 * * 1 1 * *
|
||||
# * * * * * 1 1 * 1 1 * * * 1 1 *
|
||||
# * * * * * * 1 1 1 1 * * * * 1 1
|
||||
# l=1/r=0(br) l=1/r=0/s=2(br)
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2
|
||||
|
||||
# x=-1/y=5
|
||||
|
||||
# * * * * * * * * * * * *
|
||||
# * * * * * * * * * * * *
|
||||
# 1 * * * * * 1 * * * * *
|
||||
# 1 1 * * * * 1 1 * * * *
|
||||
# 1 1 1 * * * ----> 1 1 1 * * *
|
||||
# * 1 1 1 * * 1 1 1 1 * *
|
||||
# * * 1 1 1 * 1 1 1 1 1 *
|
||||
# * * * 1 1 1 1 1 * 1 1 1
|
||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2
|
||||
# x=-1/y=8
|
||||
# * * * * * * * * * *
|
||||
# * * * * * * * * * *
|
||||
# 1 * * * * ----> 1 * * * *
|
||||
# 1 1 * * * 1 1 * * *
|
||||
# 1 1 1 * * 1 1 1 * *
|
||||
# 1 1 1 1 * 1 1 1 1 *
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
254
example/ck_tile/01_fmha/utils.hpp
Normal file
254
example/ck_tile/01_fmha/utils.hpp
Normal file
@@ -0,0 +1,254 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/core/container/span.hpp"
|
||||
|
||||
enum class mode_enum
|
||||
{
|
||||
batch = 0,
|
||||
group
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, mode_enum mode)
|
||||
{
|
||||
return stream << (mode == mode_enum::batch ? "batch" : "group");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
using size_type = typename std::vector<T>::size_type;
|
||||
|
||||
os << "[";
|
||||
for(size_type idx = 0; idx < v.size(); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
os << v[idx];
|
||||
}
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
|
||||
{
|
||||
std::vector<int32_t> seqstarts = {0};
|
||||
for(int32_t seqlen : seqlens)
|
||||
{
|
||||
seqstarts.push_back(seqstarts.back() + seqlen);
|
||||
}
|
||||
assert(seqstarts.size() == seqlens.size() + 1);
|
||||
return seqstarts;
|
||||
}
|
||||
|
||||
template <typename RandomEngine>
|
||||
std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
unsigned count,
|
||||
int32_t seqlen_avg,
|
||||
int32_t seqlen_min, // if not negative, clamp min
|
||||
int32_t seqlen_max, // if not negative, clamp max
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
assert(0 < count);
|
||||
|
||||
seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
|
||||
seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
|
||||
assert(seqlen_min <= seqlen_max);
|
||||
|
||||
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
|
||||
|
||||
if(mode == mode_enum::group && 1 < count)
|
||||
{
|
||||
using size_type = std::vector<int32_t>::size_type;
|
||||
|
||||
std::uniform_int_distribution<size_type> idx_dist(0, count - 1);
|
||||
auto next_idx = std::bind(idx_dist, std::ref(random_engine));
|
||||
|
||||
std::uniform_int_distribution<size_type> step_dist(1, count - 1);
|
||||
auto next_step = std::bind(step_dist, std::ref(random_engine));
|
||||
|
||||
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
|
||||
{
|
||||
const size_type to_decrease = next_idx();
|
||||
// make sure each elements of seqlens is in range [seqlen_min, seqlen_max]
|
||||
if(seqlens[to_decrease] == seqlen_min)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
const size_type to_increase = (to_decrease + next_step()) % count;
|
||||
|
||||
if(seqlens[to_increase] >= seqlen_max)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
--seqlens[to_decrease];
|
||||
++seqlens[to_increase];
|
||||
}
|
||||
}
|
||||
|
||||
return seqlens;
|
||||
}
|
||||
|
||||
// return random integer generated uniformly in range [low, high]
|
||||
template <typename Int = int, typename RandomEngine>
|
||||
auto randint(Int low,
|
||||
Int high,
|
||||
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>, Int>
|
||||
{
|
||||
std::uniform_int_distribution<Int> dist(low, high);
|
||||
return dist(random_engine);
|
||||
}
|
||||
|
||||
// return random integers generated uniformly in range [low, high]
|
||||
template <typename Int, typename ForwardIterator, typename RandomEngine>
|
||||
auto randints(ForwardIterator first,
|
||||
ForwardIterator last,
|
||||
Int low,
|
||||
Int high,
|
||||
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>>
|
||||
{
|
||||
std::uniform_int_distribution<Int> dist(low, high);
|
||||
|
||||
std::generate(first, last, [&] { return dist(random_engine); });
|
||||
}
|
||||
|
||||
/*
|
||||
* generate missing values in *_val randomly when the number of values is smaller than batch
|
||||
* example (assume batch=3)
|
||||
* q_val=1,2,3 k_val=4,5,6 -> OK
|
||||
* q_val=1,2,3 -> OK, k same as q
|
||||
* q_val=1,2 -> OK, q will rand remaining 1 element, k same as q
|
||||
* q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element
|
||||
* q_val=1,2,3,4 -> OK, but ignore exceed one
|
||||
*
|
||||
* q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
|
||||
* q_val=1,2 k_val=4 -> not OK, k must have same splits with q
|
||||
*/
|
||||
template <typename RandomEngine>
|
||||
std::tuple<std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>>
|
||||
generate_missing_seqlens(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
const std::vector<ck_tile::index_t>& q_val,
|
||||
const std::vector<ck_tile::index_t>& k_val,
|
||||
const std::vector<ck_tile::index_t>& q_pad_val,
|
||||
const std::vector<ck_tile::index_t>& k_pad_val,
|
||||
ck_tile::index_t seqlen_k_min,
|
||||
bool need_append_kvcache,
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
if(mode == mode_enum::batch)
|
||||
{
|
||||
ck_tile::index_t q = q_val[0];
|
||||
ck_tile::index_t k = k_val[0];
|
||||
|
||||
auto s_q = std::vector<ck_tile::index_t>(batch, q);
|
||||
auto s_k = [&] {
|
||||
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
|
||||
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
|
||||
|
||||
if(1 < batch && need_append_kvcache)
|
||||
{
|
||||
// to keep the original s_k value, we always use seqlen_k_max in first batch
|
||||
randints(std::next(seqlen_ks.begin()),
|
||||
seqlen_ks.end(),
|
||||
seqlen_k_min,
|
||||
seqlen_k_max,
|
||||
random_engine);
|
||||
return seqlen_ks;
|
||||
}
|
||||
|
||||
return seqlen_ks;
|
||||
}();
|
||||
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
|
||||
auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
|
||||
// s_k should be greater than or equal to seqlen_k_min if provided
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
{
|
||||
std::ostringstream msg;
|
||||
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
|
||||
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<ck_tile::index_t> s_q;
|
||||
std::vector<ck_tile::index_t> s_k;
|
||||
std::vector<ck_tile::index_t> s_kpad;
|
||||
std::vector<ck_tile::index_t> s_qpad;
|
||||
ck_tile::index_t idx = 0;
|
||||
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
|
||||
{
|
||||
ck_tile::index_t q = q_val[idx];
|
||||
ck_tile::index_t k =
|
||||
k_val[std::min(idx, static_cast<ck_tile::index_t>(k_val.size()) - 1)];
|
||||
ck_tile::index_t kp =
|
||||
k_pad_val.empty()
|
||||
? -1
|
||||
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
|
||||
|
||||
ck_tile::index_t qp =
|
||||
q_pad_val.empty()
|
||||
? -1
|
||||
: q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];
|
||||
|
||||
s_q.push_back(q);
|
||||
s_k.push_back(k < 0 ? q : k);
|
||||
s_kpad.push_back(kp);
|
||||
s_qpad.push_back(qp);
|
||||
|
||||
// s_k should be greater than or equal to seqlen_k_min
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
{
|
||||
std::ostringstream msg;
|
||||
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
|
||||
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
if(idx < batch)
|
||||
{
|
||||
auto rem_q =
|
||||
generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine);
|
||||
auto rem_k = generate_seqlens(
|
||||
mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine);
|
||||
|
||||
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
|
||||
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
|
||||
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
|
||||
s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
|
||||
}
|
||||
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RandomAccessIterator, typename Int, typename RandomEngine>
|
||||
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
|
||||
RandomAccessIterator last,
|
||||
Int value,
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
std::iota(first, last, value);
|
||||
std::shuffle(first, last, random_engine);
|
||||
}
|
||||
Reference in New Issue
Block a user