This commit is contained in:
Juuso Korhonen
2025-10-16 08:57:14 +00:00
46 changed files with 8970 additions and 99 deletions

View File

@@ -20,31 +20,22 @@
#include "fmha_fwd_v3.hpp"
#include "mask.hpp"
#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \
#define INST_UNIFIED_ATTENTION_V3_DISPATCH(kernel_traits) \
template <> \
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch<kernel_traits>( \
const fmha_fwd_v3_args& args, const stream_config& config) \
std::pair<bool, float> unified_attention_kernel_dispatch<kernel_traits>( \
const unified_attention_args& args, const stream_config& config) \
{ \
return std::make_pair(true, \
fmha_fwd_v3_kernel_launch<kernel_traits::kernel>(args, config)); \
unified_attention_kernel_launch<kernel_traits::kernel>(args, config)); \
}
namespace ck_tile {
template <fmha_fwd_v3_args::data_type_enum DataType>
struct fmha_fwd_v3_problem_traits;
template <unified_attention_args::data_type_enum DataType>
struct unified_attention_problem_traits;
template <>
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::fp16>
{
using qkvp_dtype = ck_tile::half_t;
using acc_dtype = float;
using o_dtype = ck_tile::half_t;
using lse_dtype = float;
};
template <>
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::bf16>
struct unified_attention_problem_traits<unified_attention_args::data_type_enum::bf16>
{
using qkvp_dtype = ck_tile::bf16_t;
using acc_dtype = float;
@@ -52,8 +43,8 @@ struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::bf16>
using lse_dtype = float;
};
template <fmha_fwd_v3_args::data_type_enum DataType, bool IsVariableSeqlen, bool IsMasking>
struct fmha_fwd_v3_kernel_traits
template <unified_attention_args::data_type_enum DataType, bool IsVariableSeqlen, bool IsMasking>
struct unified_attention_kernel_traits
{
static constexpr auto date_type = DataType;
static constexpr bool is_variable_seqlen = IsVariableSeqlen;

View File

@@ -0,0 +1,238 @@
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Currently only gfx9 archs are supported by FMHA
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
if(NOT INST_TARGETS)
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()
# validate user-specified fmha_fwd API list
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill")
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
if(BUILD_TESTING)
# Build instances of all APIs for tests
set(FMHA_FWD_ENABLE_APIS "all")
endif()
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
endif()
foreach(api ${FMHA_FWD_ENABLE_APIS})
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
endif()
endforeach()
# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
list(PREPEND FMHA_FWD_ENABLE_APIS "fwd")
endif()
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
${CMAKE_CURRENT_LIST_DIR}/generate.py
${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
)
# re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
set(FMHA_FWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS}
--optdim 32,64,128,256
# --filter fmha_fwd...
)
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd
--receipt 3
--optdim 32,64,96,128,256
# --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd...
)
# Reduce building time by disabling instances that are not currently used in the gtests
# TODO: Consider to use a special receipt for testing only, or even two receipts: a small subset of
# instances for quick CI runs and a larger subset for scheduled runs (the tests skip tests when
# there is no corresponding instance for parameters).
if(BUILD_TESTING)
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
endif()
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
endif()
execute_process(
COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.")
endif()
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
# as current cmake list, otherwise will not figure out the dependency properly
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
)
add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
)
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}")
add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS})
set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}")
add_library(${FMHA_BWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
target_include_directories(${FMHA_BWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${FMHA_BWD_INSTANCES} PRIVATE ${FMHA_BWD_GEN_BLOBS})
set_source_files_properties(${FMHA_BWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_property(TARGET ${FMHA_BWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
set(FMHA_FWD_PRIVATE_COMPILE_OPTIONS)
set(FMHA_BWD_PRIVATE_COMPILE_OPTIONS)
set(FMHA_FWD_INTERFACE_COMPILE_OPTIONS)
set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# ... because they are auto-generated
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template)
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template)
# Allow comparing floating points directly in order to check sentinel values
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal)
# NOTE: this is dangerous since will change the whole kernel to flush denormals
# WIP with compiler team for an exp2 intrinsic..., then remove this
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
set(FMHA_FWD_FAST_EXP2 ON)
endif()
if(FMHA_FWD_FAST_EXP2)
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero)
# conditionally enable call to the fwd_splitkv API in fmha_fwd example and tests
if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
else()
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
endif()
# conditionally enable call to the fwd_appendkv API in fmha_fwd example and tests
if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
else()
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
endif()
# conditionally enable call to the pagedkv_prefill API in fmha_fwd example and tests
if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1)
else()
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
endif()
# conditionally specify the use of OCP_FP8
if(CK_USE_OCP_FP8)
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
# use RTN_ASM on float to bfloat16 conversion by default, align with FA upstream
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
target_compile_options(${FMHA_FWD_INSTANCES}
PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS}
INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS})
target_compile_options(${FMHA_BWD_INSTANCES}
PRIVATE ${FMHA_BWD_PRIVATE_COMPILE_OPTIONS}
INTERFACE ${FMHA_BWD_INTERFACE_COMPILE_OPTIONS})
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
# not using add_example_executable() to add this target, since we don't want this to be included in
# "make all/install/check"
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
# not using add_example_executable() to add this target, since we don't want this to be included in
# "make all/install/check"
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# add fmha_fwd_v3 example
set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")
add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
"${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
)
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
fmha_fwd_v3.cpp
${FMHA_FWD_V3_INSTANCES}
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
-fgpu-flush-denormals-to-zero
-Wno-undefined-func-template
--save-temps
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS)
check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32)
if(HAS_DISABLE_PACKED_FP32)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
-mllvm --amdgpu-disable-packed-fp32=1
)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS
-DCK_TILE_DISABLE_PACKED_FP32=1
)
endif()
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,159 @@
# fused multi-head attention
This folder contains example for fmha(fused multi-head attention) using ck_tile tile-programming implementation. It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_fmha_fwd -j
```
This will result in an executable `build/bin/tile_example_fmha_fwd`
## kernel
The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template.
There are 2 template parameters for this kernel template.
* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)).
* `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support.
## codegen
To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable.
## executable
`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_fmha_fwd -?` to list all the arguments. Below is an example of the output (may subject to change)
```
args:
-v weather do CPU validation or not (default:1)
-mode kernel mode. 0:batch, 1:group (default:0)
-b batch size (default:2)
-h num of head, for q (default:8)
-h_k num of head, for k/v, -1 means equal to h (default:-1)
if not equal to h, then this is GQA/MQA case
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode)
-s_k seqlen_k (including new key/value), -1 means equal to s (default:-1)
also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode)
-s_qpad seqlen_q stride between 2 batches (group-mode optional) (default:-1)
Provide positive strides per-batch to simulate physical padding on Q
-s_kpad seqlen_k stride between 2 batches, currently used in group-mode only (default:-1)
for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride
along seqlen, instead of packed, same as xformer kv_padding,
must be greater than or equal to s_k
-d head dim for q, k (default:128)
-d_v head dim for v, -1 means equal to d (default:-1)
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
note when squant=1, this value will be modified by range_q/k
-range_q per-tensor quantization range of q. used if squant=1. (default:16)
-range_k per-tensor quantization range of k. used if squant=1. (default:16)
-range_v per-tensor quantization range of v. used if squant=1. (default:16)
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
-squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto)
0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O.
calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o
-iperm permute input (default:1)
if true, will be b*h*s*d, else b*s*h*d
-operm permute output (default:1)
-bias n or 0, no bias (default:n)
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
a(libi) or 2, alibi with 1*h. a:1, b*h
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
't', top-left causal mask, 'b', bottom-r causal mask
't:l,r', top-left sliding window attn(swa) with FA style left right size
'b:l,r', bottom-r sliding window attn(swa) with FA style left right size
'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
-lse 0 not store lse, 1 store lse (default:0)
-kname if set to 1 will print kernel name (default:0)
-init init method. ui, uniform random int, ni, normalized random int (default:uf)
uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
-drop_seed seed for random number generator (default:1)
-drop_offset offset for random number generator (default:0)
-drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0)
-num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1)
-warmup number of iterations before benchmark the kernel (default:5)
-repeat number of iterations to benchmark the kernel (default:20)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:fmha_fwd.json)
-q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"")
Comma-separated list of length 'b'. If empty, no override
-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"")
Comma-separated list of length 'b'. If empty, no override
```
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case
## Padding Examples
Example 3 (Group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively.
Example 4 (Batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively.
## support features
Currently we are still in rapid development stage, so more features/optimizations will be coming soon.
### hdim
Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` is better optimized. hdim should be multiple of 8, while seqlen_s can be arbitrary. For hdim be arbitrary number, it can be support through padding kernel of `qr` pipeline (we didn't generate this in generate.py by default)
### group/batch mode
Currently we support both `batch mode` and `group mode` (or `varlen`, in FA's term), by setting `-mode` = `0` or `1`. In `group mode` different kind of attention mask is also supported(see below)
### MQA/GQA
By setting `-h`(nhead for q) and `-h_k`(nhead for k/v) with different number, you can achieve MQA/GQA. Please pay attention that `h % h_K == 0` when you set different numbers.
### input/output permute, and `b*s*3*h*d`
If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support providing arbitrary stride for seqlen(stride_q/k/v), nhead, batch of q/k/v matrix, hence it is very flexible to support `b*h*s*d` or `b*s*h*d` input/output permute. The `-iperm=0/1`, `-operm=0/1` is a convenient way to achieve this through the executable. We didn't provide a command-line arg to test `b*s*3*h*d` layout which is by default used by torch/FA, but it's trivial to achieve this if one set the proper `stride_q/k/v` value as `3*h*d`.
### attention bias
Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number.
### alibi
alibi is supported
### lse
For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1`
### vlayout
We support v matrix in both row-major(`seqlen*hdim`) and col-major(`hdim*seqlen`). Since the accumulate(reduce) dimension for V is along `seqlen`, for current AMD's mfma layout which expect each thread to have contiguous register holding pixels along reduce dimension, it's easier to support col-major V layout. However, the performance of col-major is not necessarily faster than row-major, there are many factors that may affect the overall performance. We still provide the `-vlayout=r/c` here to switch/test between different layouts.
### attention mask
we support `causal mask` and `sliding window attention(swa)` mask in both batch and group mode, either from top-left or bottom-right.
Underneath, we unify the mask expression into `generic attention mask coordinate`, providing an uniformed approach for each batch to locate the corresponding pixel need to be masked out.
![](misc/gamc.png)
Since FA/xformer style with window_size_left/right is more popular, we accept window_size as parameter and convert that internally to our generic coordinate(this coordinate can express more cases). Below shows some example of how to achieve different kind of mask through cmdline.
| mask case| cmdline | FA style | xformer style |
|----------|:-------------:|:-------------:|:-------------:|
| no mask | `-mask=0`(default) | | |
| causal mask from top-left | `-mask=1` or `-mask=t` | `-mask=t:-1,0` | `-mask=xt:-1` |
| causal mask from bottom-right | `-mask=2` or `-mask=b` | `-mask=b:-1,0` | `-mask=xb:-1` |
| swa from top-left | | `-mask=t:3,5` | `-mask=xt:4` |
| swa from bottom-right | | `-mask=b:10,11` | `-mask=xb:16` |
Note FA use bottom-right by default to express swa case, here we require you explicitly specify top-left/bottom-right.
### dropout
TBD
### sequence padding and variable length support
We support sequence padding and variable-length processing in both batch and group modes fmha forward to handle real-world scenarios where sequences have different lengths.
**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment.
**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste.
Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.
## FP8 experimental support
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+.
Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later.

View File

@@ -0,0 +1,114 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
// keep sync with BlockAttentionBiasEnum
enum class bias_enum
{
no_bias = 0,
elementwise_bias = 1,
alibi = 2,
};
struct bias_info
{
bias_enum type;
/*
* simple dispatch logic
*
* if type == elementwise_bias:
* if rank_info == 0:
* bias is 1*1*s*s
* elif rank_info == 1:
* bias is 1*h*s*s
* elif rank_info == 2:
* bias is b*h*s*s
*
* elif type == alibi:
* if rank_info == 0:
* alibi in 1*h
* elif rank_info == 1:
* alibi in b*h
*/
int rank_info;
void serialize(std::ostream& os) const
{
if(type == bias_enum::no_bias)
os << "n";
else if(type == bias_enum::elementwise_bias)
{
os << "e";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
else if(type == bias_enum::alibi)
{
os << "alibi";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
}
static bias_info decode(std::string str)
{
bias_info info{bias_enum::no_bias, 0};
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string t = str.substr(0, found_0);
std::string v = str.substr(found_0 + 1);
if(t == "e" || t == "elementwise")
{
info.type = bias_enum::elementwise_bias;
info.rank_info = std::stoi(v);
if(info.rank_info < 0 || info.rank_info > 2)
throw std::invalid_argument("invalid bias rank: " + str);
}
else if(t == "a" || t == "alibi")
{
info.type = bias_enum::alibi;
info.rank_info = std::stoi(v);
if(info.rank_info < 0 || info.rank_info > 1)
throw std::invalid_argument("invalid bias rank: " + str);
}
else
{
throw std::invalid_argument("invalid bias value: " + str);
}
}
else if(str == "0" || str == "n")
{
info.type = bias_enum::no_bias;
}
else if(str == "1" || str == "e" || str == "elementwise")
{
info.type = bias_enum::elementwise_bias;
}
else if(str == "2" || str == "a" || str == "alibi")
{
info.type = bias_enum::alibi;
}
else
{
throw std::invalid_argument("invalid bias value: " + str);
}
return info;
}
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
{
bi.serialize(os);
return os;
}
};

View File

@@ -0,0 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
GEN_DIR = "" # in Cmake, have to generate files in same folder

View File

@@ -0,0 +1,138 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
FWD_DTYPE_MAP = {
"fp32" : "FmhaFwdFp32",
"fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16",
"fp8fp32": "FmhaFwdFp8Fp32"
}
BWD_DTYPE_MAP = {
"fp32": "FmhaBwdFp32",
"fp16": "FmhaBwdFp16",
"bf16": "FmhaBwdBf16"
}
MASK_IMPL = {
"generic" : "ck_tile::GenericAttentionMask",
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
}
_MASK_SIMPLIFIED_MAP = {
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
}
_MASK_MAP = {
"no" : "FmhaMasks::NoMask",
"causal" : "FmhaMasks::CausalMask",
"generic" : "FmhaMasks::GenericMask"
}
def get_mask_map(mask : str):
if mask == "generic":
return _MASK_MAP
elif mask == "simplified":
return _MASK_SIMPLIFIED_MAP
else:
assert False
return None
_MASK_CHECK_MAP = {
"no" : "t.mask_type == mask_enum::no_mask",
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic" : "t.mask_type == mask_enum::window_generic",
}
_MASK_SIMPLIFIED_CHECK_MAP = {
"s_no" : "t.mask_type == mask_enum::no_mask",
"s_mask" : "t.mask_type != mask_enum::no_mask",
}
def get_mask_check_map(mask : str):
if mask == "generic":
return _MASK_CHECK_MAP
elif mask == "simplified":
return _MASK_SIMPLIFIED_CHECK_MAP
else:
assert False
return None
BIAS_MAP = {
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
}
# TODO: this is ugly
BIAS_CHECK_MAP = {
"no" : "bias_enum::no_bias",
"bias" : "bias_enum::elementwise_bias",
"alibi" : "bias_enum::alibi"
}
DROPOUT_MAP = {
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
}
DROPOUT_CHECK_MAP = {
"no" : "t.has_dropout == false",
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
}
ROPE_MAP = {
"no" : "ck_tile::RotaryEmbeddingEnum::NONE",
"inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
"half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
}
ROPE_CHECK_MAP = {
"no" : "rope_enum::none",
"inter" : "rope_enum::interleaved",
"half" : "rope_enum::half_rotated"
}
MODE_MAP = {
"batch" : "false",
"group" : "true"
}
LAYOUT_MAP = {
"row" : "true",
"col" : "false"
}
PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
}
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
"qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
}
BOOL_MAP = {
"t" : "true",
"f" : "false",
True : "true",
False : "false",
}

View File

@@ -0,0 +1,633 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8" : 8,
"bf8" : 8
}
K0_MAX_SUBMAX_MAP = {
32 : 32,
64 : 64,
96 : 128,
128: 128,
256: 256
}
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
"qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_logits},
{F_bias},
false,
{F_lse},
{F_dropout},
{F_squant},
{F_occupancy}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
false,
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
#include <iostream>
template<>
float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_batch_prefill_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
FMHA_FWD_API="""
#include <cstdio>
namespace {{
bool get_num_cus(unsigned& num_cu) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cu = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
bool_expr: str = None
def __str__(self):
if self.bool_expr is None:
return 'true'
else:
return f'{self.bool_expr}'
def __and__(self, other):
return CppConstraint(f'({str(self)}) && ({str(other)})')
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
dropout : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
constraint : CppConstraint
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
@property
def scheck(self) -> str:
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
else: assert False
@property
def dcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
else: assert False
@property
def dvcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
else: assert False
@dataclass
class FmhaFwdPipeline:
tag : str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_logits == 't' : n += '_logits'
else: n += '_nlogits'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else: n += '_nmask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
else: n += '_nmask'
if self.F_lse == 't' : n += '_lse'
else: n += '_nlse'
if self.F_dropout == 't' : n += '_dropout'
else: n += '_ndropout'
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
return n
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
@dataclass
class FmhaFwdTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass
class FmhaFwdKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0max = self.F_tile.F_bk0max,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag])
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
class KernelComponentFactory:
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
}
else:
return None
@staticmethod
def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
else:
assert False
return pipelines
class CustomFactory(KernelComponentFactory):
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
if dtype == 'fp16' or dtype == 'bf16':
if 128 in result.keys():
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')))
return result
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
for dtype in FWD_DTYPE_MAP.keys():
d = CustomFactory.get_hdim_tile_size_dict(dtype)
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
continue
k = FmhaFwdKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'batch'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")

View File

@@ -0,0 +1,929 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Tuple, Dict, Literal, Any
from collections import defaultdict
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.utils import update_file
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "fmha_bwd.hpp"
"""
FMHA_BWD_DQ_DK_DV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::
sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
using fmha_warp_tile2_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, ck_tile::min({F_wk0}, {F_bk4})>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile0_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile1_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile0_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile1_{F_idx},
fmha_block_warps2_{F_idx},
fmha_warp_tile2_{F_idx},
{F_maxq}>;
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaBwdTraits<{F_dpad},
{F_dvpad},
{F_bias},
{F_dbias},
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_dropout_{F_idx} = {F_dropout};
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
fmha_bwd_shape_{F_idx},
{F_mode},
{F_deterministic},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
{F_trload},
fmha_bwd_trait_{F_idx}>;
using fmha_bwd_pipeline_{F_idx} = ck_tile::BlockFmhaBwdDQDKDVPipeline<fmha_bwd_pipeline_problem_{F_idx}>;
using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
false,
({F_dpad} > 0)>>;
using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
false,
({F_dvpad} > 0)>>;
using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType,
false,
({F_dpad} > 0)>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx},
fmha_bwd_dq_epilogue_{F_idx}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dtype},
{F_mode},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
{F_bias},
{F_dbias},
{F_dpad},
{F_dvpad},
{F_deterministic},
{F_trload},
{F_maxq},
{F_bn0}>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::kMaxSeqLenQ;
}}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::GetName();
}}
"""
FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_API="""
#include <iostream>
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if constexpr (!std::is_same_v<convert_dq_trait_, void>)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
}}
else
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
);
}}
}}
template <>
float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
[[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported();
float r = -1;
{F_dispatch}
return r;
}}
"""
def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) -> str:
lines = [
f"{'if' if if_ == 0 else 'else if'}({F_cond})",
"{",
*[' ' + line for line in F_body.split('\n') if line.strip() != ''],
"}",
]
return '\n'.join(' ' * indent + line for line in lines) + '\n'
FMHA_BWD_API_INNER_DISPATCH="""
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
return r;
}}
"""
# M0 size for 1d kernels (dot/convert)
M0_1D = 64
# GEMM0: Q@K=S^T
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk)
# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk)
# Is it necessary to distinguish between K0~K4?
@dataclass(frozen=True)
class FmhaBwdDQDKDVTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along gemm0 unroll(F_bhdq)
F_bk1 : int # tile size along gemm1 unroll(F_bm0)
F_bk2 : int # tile size along gemm2 unroll(F_bhdv)
F_bk3 : int # tile size along gemm3 unroll(F_bm0)
F_bk4 : int # tile size along gemm4 unroll(F_bn0)
F_bhdq : int # q head_dim
F_bhdv : int # v head_dim
F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2
F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2
F_rk0 : int # number of warps along headdim_qk/v (not used) in gemm0/gemm2
F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3
F_rn1 : int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3
F_rk1 : int # number of warps along q seqlen (not used) in gemm1/gemm3
F_rm2 : int # number of warps along q seqlen (block warps) in gemm4
F_rn2 : int # number of warps along headdim_qk (block warps) in gemm4
F_rk2 : int # number of warps along k seqlen (not used) in gemm4
F_wm0 : int # warp size along m in gemm0/gemm2/gemm4
F_wn0 : int # warp size along n in gemm0/gemm2/gemm4
F_wk0 : int # warp size along k in gemm0/gemm2/gemm4
F_wm1 : int # warp size along m in gemm1/gemm3
F_wn1 : int # warp size along n in gemm1/gemm3
F_wk1 : int # warp size along k in gemm1/gemm3
F_occupancy : int # occupancy
max_seq_q : int = 0
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}_maxq{self.max_seq_q}"
@dataclass(frozen=True)
class FmhaBwdDQDKDVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaBwdDQDKDVTileSize
F_dpad : Literal[0, 8 ,1]
F_dvpad : Literal[0, 8 ,1]
F_bias : str #
F_dbias : str #
F_dropout : str #
F_mask : str # value from MASK_MAP
F_mode : str # value from MODE_MAP
F_deterministic : str #
mask_impl : str #
F_trload : str #
@property
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + \
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bk1 = self.F_tile.F_bk1,
F_bk2 = self.F_tile.F_bk2,
F_bk3 = self.F_tile.F_bk3,
F_bk4 = self.F_tile.F_bk4,
F_bhdq = self.F_tile.F_bhdq,
F_bhdv = self.F_tile.F_bhdv,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_rm2 = self.F_tile.F_rm2,
F_rn2 = self.F_tile.F_rn2,
F_rk2 = self.F_tile.F_rk2,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_dpad = self.F_dpad,
F_dvpad = self.F_dvpad,
F_bias = BIAS_MAP[self.F_bias],
F_dbias = BOOL_MAP[self.F_dbias],
F_dropout = DROPOUT_MAP[self.F_dropout],
F_occupancy = self.F_tile.F_occupancy,
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_deterministic = BOOL_MAP[self.F_deterministic],
F_trload = BOOL_MAP[self.F_trload],
F_maxq = self.F_tile.max_seq_q
)
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_dpad : n += f'd{self.F_dpad}'
if self.F_dvpad : n += f'dv{self.F_dvpad}'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
if self.F_dbias == 't' : n += '_dbias'
else: n += '_ndbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else: n += '_nmask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
else: n += '_nmask'
if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
else: n += '_ndropout'
if self.F_deterministic == 't' : n += '_deterministic'
else: n += '_ndeterministic'
if self.F_trload == 't' : n += '_trload'
else: n += '_ntrload'
return n
@property
def filename(self) -> str:
return self.name + ".cpp"
# TODO: design a more practical way to do it
# this is current supported tile size.
def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
if dtype == 'fp32' and tr_load == 'f':
return [
# bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
]
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f':
return [
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
]
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't':
return [
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
]
else:
return []
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_dot_do_o_trait_{F_idx} =
ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>;
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
/* BlockSize = M0 = */ 64,
{F_hdim},
{F_mode},
fmha_bwd_dot_do_o_trait_{F_idx}>;
using fmha_bwd_dot_do_o_{F_idx} =
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_kernel_{F_idx} =
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;
using dot_do_o_trait_{F_idx} =
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
#include <iostream>
template <>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
return k_::GetName();
}}
"""
@dataclass(frozen=True)
class FmhaBwdOGradDotOKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_spad : str # true/false
F_dvpad : str #
F_mode : str # value from MODE_MAP
F_occupancy : int
@property
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + \
FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_spad = BOOL_MAP[self.F_spad],
F_dvpad = BOOL_MAP[self.F_dvpad],
F_mode = MODE_MAP[self.F_mode],
F_occupancy = self.F_occupancy)
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}"
if pn != '' : n += f'_{pn}'
else: n += '_npad'
return n
@property
def filename(self) -> str:
return self.name + ".cpp"
FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_convert_dq_trait_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;
using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
/* BlockSize = */ 256,
{F_bm0},
{F_bn0},
{F_hdim},
{F_mode},
{F_deterministic},
fmha_bwd_convert_dq_trait_{F_idx}>;
using fmha_bwd_convert_dq_{F_idx} =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
using fmha_bwd_convert_dq_kernel_{F_idx} =
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_dtype},
{F_mode},
{F_spad},
{F_dpad},
{F_deterministic},
{F_bn0}>;
#include <iostream>
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
return k_::GetName();
}}
"""
@dataclass(frozen=True)
class FmhaBwdConvertQGradKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_spad : str # true/false
F_dpad : str #
F_mode : str # value from MODE_MAP
F_occupancy : int #
F_deterministic : str #
disabled : bool # sometimes this kernel is not used
@property
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + \
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_bm0,
F_bn0 = self.F_bn0,
F_spad = BOOL_MAP[self.F_spad],
F_dpad = BOOL_MAP[self.F_dpad],
F_mode = MODE_MAP[self.F_mode],
F_occupancy = self.F_occupancy,
F_deterministic = BOOL_MAP[self.F_deterministic])
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_dpad == 't' : n += 'd'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_deterministic == 't' : n += '_deterministic'
else: n += '_ndeterministic'
return n
@property
def filename(self) -> str:
return self.name + ".cpp"
@dataclass(frozen=True)
class FmhaBwdApiTrait:
idx : int # this is not a tunable, but a counter to differentiate symbol
# sync with fmha_bwd_traits<>, to generate fallback calls
hdim : int
dtype : str # data type
mode : str # value from MODE_MAP
tile : FmhaBwdDQDKDVTileSize
mask : str
bias : str
dbias : str
dropout : str
spad1d : str # spad for 1d kernels (dot/convert)
dpad : Literal[0, 1, 8]
dvpad : Literal[0, 1, 8]
deterministic : str
mask_impl : str
tr_load : str
@property
def bm0(self) -> int:
return self.tile.F_bm0
@property
def bn0(self) -> int:
return self.tile.F_bn0
@property
def bhdq(self) -> int:
return self.tile.F_bhdq
@property
def bhdv(self) -> int:
return self.tile.F_bhdv
@property
def scheck(self) -> str:
if self.mode == 'group':
return 'true' # always support
elif self.spad1d == 't':
return f'a.seqlen_q % {M0_1D} != 0'
else: # self.spad1d == 'f'
return f'a.seqlen_q % {M0_1D} == 0'
@property
def dcheck(self) -> str:
if self.dpad == 0: return f'a.hdim_q % {self.bhdq} == 0'
else: return f'a.hdim_q % {self.dpad} == 0'
@property
def dvcheck(self) -> str:
if self.dvpad == 0: return f'a.hdim_v % {self.bhdv} == 0'
else: return f'a.hdim_v % {self.dvpad} == 0'
@property
def extra_cond(self) -> str:
if self.tr_load == 't' and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
return "&& (a.seqlen_k <= 256)"
else:
return ""
@property
def convert_dq_bn0(self) -> int:
return self.tile.F_bn0 if self.deterministic == 't' else 0
@property
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
# support this in future
def get_occupancy(dtype, hdim):
return 2
F_dvpad = 't' if self.dvpad else 'f'
return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d,
F_dvpad=F_dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim))
@property
def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile,
F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout,
F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load)
@property
def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
# support this in future
def get_occupancy(dtype, hdim):
return 2
F_dpad = 't' if self.dpad else 'f'
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad,
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
class FmhaBwdApiPool:
def __init__(self, mask_impl):
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
self.mask_impl = mask_impl
def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
# TODO: do we need to check duplication?
self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][trait.hdim].append(copy.copy(trait))
@staticmethod
def if_(i: int) -> str:
return 'if' if i == 0 else 'else if'
def _api_innders(self, traits: List[FmhaBwdApiTrait]) -> str:
inners = ""
i = 0
for trait in traits:
inners += FMHA_BWD_API_INNER_DISPATCH.format(F_if=self.if_(i), F_mode=MODE_MAP[trait.mode],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=trait.dpad, F_dvpad=trait.dvpad,
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond,
F_convert_dq_bn0=trait.convert_dq_bn0)
i += 1
return inners
@staticmethod
def trload_sort_key(tf):
return 0 if tf == 't' else 1 # sort 't' before 'f'
@staticmethod
def max_seq_q_sort_key(max_seq_q):
return max_seq_q if max_seq_q != 0 else 1000000 # sort 0 to the end
@staticmethod
def max_seq_q_cond(max_seq_q: int) -> str:
if max_seq_q == 0:
return 'true /* no seqlen_q limit */'
else:
return f'a.seqlen_q <= {max_seq_q}'
@staticmethod
def dtype_cond(dtype: str) -> str:
return f't.data_type.compare("{dtype}") == 0'
@staticmethod
def hdim_cond(hdim: int) -> str:
return f't.hdim_q <= {hdim} && t.hdim_v <= {hdim}'
@property
def api(self) -> str:
tr_load_cond_map = {
"t": "has_load_tr",
"f": "true /* no trload requirement */"
}
per_tr_load = ''
for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key):
per_max_seq_q = ''
for max_seq_q in sorted(self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key):
per_dtypes = ''
for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]):
per_hdim_case = ''
for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]):
traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim]
inners = self._api_innders(traits)
per_hdim_case += FMHA_BWD_API_COND_STATEMENT(if_=k, F_cond=self.hdim_cond(hdim), F_body=inners)
per_dtypes += FMHA_BWD_API_COND_STATEMENT(if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case)
per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes)
per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4)
if not per_tr_load:
# empty string we add some ignore to suppress warning in api
per_tr_load += ' (void)t ; (void)s ; (void)a; (void)has_load_tr;'
result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
return result.replace('\n\n', '\n')
def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]:
if filter_list == '':
filter_list = '*@*@*'
filters = filter_list.split('@')
filters.extend(['*'] * (3 - len(filters)))
filter_dot_do_o = filters[0]
filter_convert_dq = filters[1]
filter_dq_dk_dv = filters[2]
# use dict as ordered set
gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {}
gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {}
gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {}
api_pool = FmhaBwdApiPool(mask_impl)
for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load)
dpad_options = itertools.product(*([[0, 8, 1]] * 2))
tf = ["t", "f"]
for tile, mode, mask, bias, dbias, dropout, spad1d, (dpad, dvpad), deterministic in itertools.product(
tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), tf, dpad_options, tf):
assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize"
hdim = tile.F_bhdq
if (mode == "group") and (spad1d == "f"):
continue
if (mode == "group" or ('no' not in mask)) and tile.max_seq_q != 0:
continue
if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue
if ("wg32" in dropout):
continue
if tr_load == "t":
continue # tr_load cannot work with dpad or dvpad
else: # tr_load == "f"
# do not generate instance with only 1 of dpad/dvpad being 8
if dpad != dvpad and dpad == 8:
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load)
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
continue
if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv):
continue
if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq):
continue
# Flash attention integration
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
if not cond:
continue
elif receipt == 3:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'bias']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond:
continue
# Aiter (mha_bwd) integration
elif receipt == 300:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
if not cond:
continue
# Aiter (mha_varlen_bwd) integration
elif receipt == 400:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
if not cond:
continue
# aiter::mha_bwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
# fp32 only, all variations
if receipt == 800:
cond = dtype == 'fp32'
cond &= dpad == dvpad
if not cond:
continue
# fp32 only, minimal set of parameters
elif receipt == 801:
cond = dtype == 'fp32'
cond &= hdim in [64, 128]
cond &= dpad == dvpad
cond &= mode == 'batch'
cond &= bias == 'no'
cond &= dropout == 'no'
cond &= mask == 's_no'
cond &= deterministic == "f"
if not cond:
continue
else:
# Don't build fp32 by default
if dtype == 'fp32':
continue
gen_dot_do_o[t.dot_do_o_kernel] = True
gen_dq_dk_dv[t.dq_dk_dv_kernel] = True
if not t.convert_dq_kernel.disabled:
gen_convert_dq[t.convert_dq_kernel] = True
api_pool.register_dq_dk_dv_traits(t)
return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys())
def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list)
update_file(output_dir / FMHA_BWD_API_FILENAME, api_pool.api)
for k in kernels_dot_do_o:
update_file(output_dir / k.filename, k.template)
for k in kernels_convert_dq:
update_file(output_dir / k.filename, k.template)
for k in kernels_dq_dk_dv:
update_file(output_dir / k.filename, k.template)
def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None:
_, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(
filter_list, receipt, mask_impl, optdim_list
)
with file_path.open("a") as f:
for k in kernels_dot_do_o:
f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n")
for k in kernels_dq_dk_dv:
f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n")
for k in kernels_convert_dq:
f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")

View File

@@ -0,0 +1,783 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
import os
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.utils import update_file
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8" : 8,
"bf8" : 8
}
K0_MAX_SUBMAX_MAP = {
32 : 32,
48 : 48,
64 : 64,
96 : 128,
128: 128,
192: 192,
256: 256
}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_logits},
{F_bias},
false,
{F_lse},
{F_dropout},
{F_squant},
{F_occupancy},
{F_skip}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
{F_trload},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
#include <iostream>
template<>
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
FMHA_FWD_API="""
#include <cstdio>
#include <hip/hip_runtime.h>
namespace {{
bool get_num_cus(unsigned& num_cus) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cus = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
[[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{
{F_dtype_case}
}}
"""
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
return fmha_fwd_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
bool_expr: str = None
def __str__(self):
if self.bool_expr is None:
return 'true'
else:
return f'{self.bool_expr}'
def __and__(self, other):
return CppConstraint(f'({str(self)}) && ({str(other)})')
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
dropout : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
skip : str
tr_load : str
constraint : CppConstraint
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
@property
def scheck(self) -> str:
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag in ['qr_async', 'qr_async_trload']:
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr', 'qs']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
def seqtune(self, max_bm0 : int) -> str:
if self.bm0 == max_bm0: return 'true/*fall back to largest tile*/'
else:
return f'a.seqlen_q <= {self.bm0}'
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)'
else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)'
elif self.pipeline_tag in ['qr', 'qs']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)'
elif self.pipeline_tag == 'qr_async_trload':
if self.skpad == 't' : return 'true'
else: return 'true'
else: assert False
@property
def dcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
else: assert False
@property
def dvcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
else: assert False
@dataclass
class FmhaFwdPipeline:
tag : str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_skip : str # true/false
F_trload : str # true/false
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_logits == 't' : n += '_logits'
else: n += '_nlogits'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else: n += '_nmask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
else: n += '_nmask'
if self.F_lse == 't' : n += '_lse'
else: n += '_nlse'
if self.F_dropout == 't' : n += '_dropout'
else: n += '_ndropout'
if self.F_skip == 't' : n += '_skip'
else: n += '_nskip'
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
if self.F_trload == 't' : n += '_trload'
else: n += '_ntrload'
return n
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
hdim = trait.hdim, trait.bn1
if hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][hdim] = list()
self.pool[trait.dtype][hdim].append(copy.copy(trait))
@property
def api(self) -> str:
tr_load_cond_map = {
"t": "has_load_tr",
"f": "true"
}
per_tr_load =str()
for tr_load in ["t", "f"]:
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load]
max_bm0 = max((t.bm0 for t in traits), default=0)
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes)
if not per_tr_load:
# empty string we add some ignore to suppress warning in api
per_tr_load += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load)
@dataclass
class FmhaFwdTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass
class FmhaFwdKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0max = self.F_tile.F_bk0max,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_trload = BOOL_MAP[self.F_pipeline.F_trload])
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip,
tr_load=self.F_pipeline.F_trload,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
class KernelComponentFactory:
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
if dtype == 'fp32':
return {
# bm0, bn0, bk0, bn1, bk1,
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
}
elif dtype == 'fp16' or dtype == 'bf16':
return {
(32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
}
elif dtype == 'fp8' or dtype == 'fp8bf16':
return {
(64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
}
elif dtype == 'fp8fp32':
return {
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
}
else:
return None
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
@staticmethod
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
pipelines = []
if dtype in ['fp32']:
squant = 'f'
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
elif dtype in ['fp16', 'bf16']:
squant = 'f'
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
if hdim == 256 and hdim_v == 256:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
# the below two is used for hdim vectorize load
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f":
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't'))
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't'))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']:
# no need lse/dropout kernels
for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
elif dtype in ['fp8fp16', 'bf8']:
# TODO
None
else:
assert False
return pipelines
class CustomFactory(KernelComponentFactory):
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
if dtype == 'fp16' or dtype == 'bf16':
if (128, 128) in result.keys():
result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')))
return result
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else KernelComponentFactory
for dtype in FWD_DTYPE_MAP.keys():
d = factory.get_hdim_tile_size_dict(dtype)
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
for tile, next_tile in zip(tiles, tiles[1:]):
assert next_tile.F_bm0 >= tile.F_bm0, 'Tiles must be ordered by increasing bm0'
for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if (hdim, hdim_v) == (192, 128):
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_dropout == 't':
continue
if dtype != 'fp32':
if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)):
# non qr_async_trload only support km0=128 tile size when hdim is not 128
# non qr_async only support kn0=128 tile size when hdim is 128
continue
if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])):
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
continue
k = FmhaFwdKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
cond &= mode == 'batch'
cond &= pipeline.F_skip == 'f'
cond &= pipeline.F_logits == 'f'
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
cond &= mode == 'batch'
cond &= pipeline.F_vlayout == 'row'
if dtype == 'fp8bf16':
cond &= hdim == 128
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
if dtype == 'fp8bf16':
cond &= hdim == 128
if not cond:
continue
# aiter::mha_fwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
cond &= pipeline.F_vlayout == 'row'
if dtype == 'fp8bf16':
cond &= hdim == 128
if not cond:
continue
elif receipt == 888:
cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32']
cond &= pipeline.F_vlayout == 'row'
cond &= hdim == 128
if not cond:
continue
# fp32 only, all variations
if receipt == 800:
cond = dtype == 'fp32'
cond &= pipeline.F_skip == 'f'
cond &= pipeline.F_logits == 'f'
if not cond:
continue
# fp32 only, minimal set of parameters
elif receipt == 801:
cond = dtype == 'fp32'
cond &= hdim in [48, 128]
cond &= mode == 'batch'
cond &= pipeline.F_bias == 'no'
cond &= pipeline.F_lse == 'f'
cond &= pipeline.F_dropout == 'f'
cond &= pipeline.F_skip == 'f'
cond &= pipeline.F_logits == 'f'
cond &= pipeline.F_mask == 's_no'
if not cond:
continue
else:
# Don't build fp32 by default
if dtype == 'fp32':
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
update_file(autogen_dir / kernel.filename, kernel.template)
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")

View File

@@ -0,0 +1,376 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.ops.fmha_fwd import (
FmhaFwdApiTrait,
DTYPE_BITS,
FMHA_FWD_KERNEL_HEADER,
FMHA_FWD_API_PER_DTYPE,
FMHA_FWD_API_PER_HDIM_CASE,
)
FMHA_FWD_APPENDKV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_occupancy}>;
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
{F_bs},
{F_bsk},
{F_bd},
{F_bdv},
{F_vlayout},
{F_rope},
{F_pagedkv},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
fmha_pipeline_problem_{F_idx}>;
using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}>;
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
#include <iostream>
template<>
float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API="""
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
return fmha_fwd_appendkv_<trait_>(s, a);
}}
"""
@dataclass
class FmhaFwdAppendKVApiTrait:
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
bs : int # tile size along q seqlen
bsk : int # tile size along k seqlen
bd : int # tile size along qk gemm unroll
bdv : int # tile size along kv gemm unroll
vlayout : str
spad : str
skpad : str
dpad : str
dvpad : str
rope : str # key from ROPE_MAP
pagedkv : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\
f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}'
@property
def scheck(self) -> str:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
else : return f'a.seqlen_q % {self.bs} == 0'
@property
def skcheck(self) -> str:
# we do not check all the values in a.seqlen_k_ptr
return 'true'
@property
def dcheck(self) -> str:
if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {self.bd} == 0'
@property
def dvcheck(self) -> str:
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {self.bdv} == 0'
@dataclass
class FmhaFwdAppendKVPipeline:
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_rope : str # key from ROPE_MAP
F_pagedkv : str # t/f
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f'v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_rope != 'no': n += f'_{self.F_rope}'
if self.F_pagedkv == 't': n += '_pagedkv'
return n
class FmhaFwdAppendKVApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
@dataclass
class FmhaFwdAppendKVTileSize:
F_bs : int # tile size along q seqlen
F_bsk : int # tile size along k seqlen
F_bd : int # tile size along qk gemm unroll
F_bdv : int # tile size along kv gemm unroll
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass
class FmhaFwdAppendKVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaFwdAppendKVTileSize
F_pipeline : FmhaFwdAppendKVPipeline
mask_impl : str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bs = self.F_tile.F_bs,
F_bsk = self.F_tile.F_bsk,
F_bd = self.F_tile.F_bd,
F_bdv = self.F_tile.F_bdv,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_rope = ROPE_MAP[self.F_pipeline.F_rope],
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy = self.F_tile.F_occupancy)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdAppendKVApiTrait:
return FmhaFwdAppendKVApiTrait(
hdim=str(self.F_hdim),
dtype=self.F_dtype,
bs=self.F_tile.F_bs,
bsk=self.F_tile.F_bsk,
bd=self.F_tile.F_bd,
bdv=self.F_tile.F_bdv,
vlayout=self.F_pipeline.F_vlayout,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
rope=self.F_pipeline.F_rope,
pagedkv=self.F_pipeline.F_pagedkv)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
}
else:
return None
def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
# NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
# applying rotary embedding, so I just use 't' in inter/half pipelines
for vlayout in ['row', 'col']:
for pagedkv in ["t", "f"]:
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
elif dtype in ['fp8', 'bf8']:
# rope/paged-kv is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else:
assert False
return pipelines
gen = list()
api_pool = FmhaFwdAppendKVApiPool(mask_impl)
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
if d == None:
continue
for hdim_str in d.keys():
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
k = FmhaFwdAppendKVKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
write_fwd_appendkv_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")

View File

@@ -0,0 +1,885 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple, Union
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.ops.fmha_fwd import (
FmhaFwdTileSize,
FmhaFwdApiTrait,
FMHA_FWD_KERNEL_HEADER,
FMHA_FWD_API_PER_DTYPE,
FMHA_FWD_API_PER_HDIM_CASE,
)
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8" : 8,
"bf8" : 8
}
K0_MAX_SUBMAX_MAP = {
32 : 32,
64 : 64,
96 : 128,
128: 128,
# 160: 160,
256: 256
}
FMHA_FWD_SPLITKV_PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS",
}
FMHA_FWD_SPLITKV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
namespace {{
template <bool kHasUnevenSplits, bool kMergeNumHeadGroupsSeqLenQ = false>
struct instance {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_logits},
{F_bias},
/*kHasBiasGrad=*/false,
{F_lse},
{F_squant},
{F_pagedkv},
kHasUnevenSplits,
kMergeNumHeadGroupsSeqLenQ,
{F_occupancy}>;
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
fmha_shape,
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
fmha_trait>;
using fmha_pipeline = {F_pipeline}<
fmha_pipeline_problem>;
/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
/// store_tile_raw() data corruption issue
using fmha_epilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
false, false>>;
using fmha_kernel =
ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}};
}}
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>;
#include <iostream>
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wtautological-compare"
namespace {{
template <bool kHasUnevenSplits>
void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
instance<kHasUnevenSplits, /*kMergeNumHeadGroupsSeqLenQ=*/true>::run(s, a);
}} else {{
instance<kHasUnevenSplits>::run(s, a);
}}
}} else {{
instance<kHasUnevenSplits>::run(s, a);
}}
}}
}} // anonymous namespace
#pragma clang diagnostic pop
template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if constexpr({F_mode} == false) {{ // batch mode
// we don't check every seqlen_k values for kvcache
if (a.seqlen_k_ptr != nullptr) {{
run_instance</*kHasUnevenSplits=*/true>(s, a);
// make sure F_bn0 is divisible by F_bk1
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
run_instance</*kHasUnevenSplits=*/false>(s, a);
}} else {{
run_instance</*kHasUnevenSplits=*/true>(s, a);
}}
}} else {{
run_instance</*kHasUnevenSplits=*/true>(s, a);
}}
}}
template<>
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
{{
using k_ = instance<true>::fmha_kernel; /// FIXME: choose real kernel type
return k_::GetName();
}}
"""
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
namespace {{
template <ck_tile::index_t kLogMaxSplits>
struct instance {{
using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad},
{F_dvpad},
{F_lse},
{F_squant},
kLogMaxSplits,
{F_occupancy}>;
using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
{F_hdim},
{F_mode},
{F_bn1},
fmha_trait>;
using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
fmha_pipeline_problem>;
/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
/// store_tile_raw() data corruption issue
using fmha_epilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
false, false>>;
using fmha_kernel =
ck_tile::FmhaFwdSplitKVCombineKernel<fmha_pipeline, fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}};
}}
using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1},
{F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
#include <iostream>
template<>
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if (a.num_splits <= 8) {{
instance<3>::run(s, a);
}} else if (a.num_splits <= 16) {{
instance<4>::run(s, a);
}} else if (a.num_splits <= 32) {{
instance<5>::run(s, a);
}} else if (a.num_splits <= 64) {{
instance<6>::run(s, a);
}} else if (a.num_splits <= 128) {{
instance<7>::run(s, a);
}}
}}
template<>
std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>()
{{
using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type
return k_::GetName();
}}
"""
FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp"
FMHA_FWD_SPLITKV_API="""
#include <iostream>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if(s.log_level_ > 0)
std::cout
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
<< std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
);
}}
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
// make sure we can reuse the padding flags in combine kernels
static_assert({F_bm0} % kM0 == 0);
static_assert({F_bn1} % 32 == 0);
if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
return -1;
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}}
"""
@dataclass
class FmhaFwdSplitKVApiTrait:
pipeline_tag : str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
mask : str
logits : str
bias : str #
lse : str #
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
pagedkv : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
f'{self.dvpad}-{self.pagedkv}'
@property
def scheck(self) -> str:
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0'
else: assert False
@property
def dcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
else: assert False
@property
def dvcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
else: assert False
@dataclass
class FmhaFwdSplitKVPipeline:
tag : str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_squant : str #
F_pagedkv : str # t/f
F_mask : str # value from MASK_MAP
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_logits == 't' : n += '_logits'
else: n += '_nlogits'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else: n += '_nmask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
else: n += '_nmask'
if self.F_lse == 't' : n += '_lse'
else: n += '_nlse'
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
if self.F_pagedkv == 't' : n += '_pagedkv'
else: n += '_npagedkv'
return n
@dataclass
class FmhaFwdSplitKVCombinePipeline:
tag : str
F_spad : str # true/false
F_dvpad : str #
F_lse : str #
F_squant : str #
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f'{self.tag}'
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_lse == 't' : n += '_lse'
else: n += '_nlse'
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
return n
class FmhaFwdSplitKVApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes)
@dataclass
class FmhaFwdSplitKVCombineTileSize:
F_bn1 : int # tile size along v head_dim
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bn1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass
class FmhaFwdSplitKVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdSplitKVPipeline
mask_impl : str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_SPLITKV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0max = self.F_tile.F_bk0max,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag])
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_fwd_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdSplitKVApiTrait:
return FmhaFwdSplitKVApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
logits=self.F_pipeline.F_logits,
mask=self.F_pipeline.F_mask,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
squant=self.F_pipeline.F_squant,
pagedkv=self.F_pipeline.F_pagedkv,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
@dataclass
class FmhaFwdSplitKVCombineKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdSplitKVCombineTileSize
F_pipeline : FmhaFwdSplitKVCombinePipeline
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bn1 = self.F_tile.F_bn1,
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy = self.F_tile.F_occupancy,
F_mode = MODE_MAP[self.F_mode])
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_fwd_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
def filename(self) -> str:
return self.name + ".cpp"
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
}
else:
return None
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdSplitKVCombineTileSize(32, -1),
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
'96' : FmhaFwdSplitKVCombineTileSize(32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
# '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
}
else:
return None
def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]:
Pipeline = FmhaFwdSplitKVPipeline
Kernel = FmhaFwdSplitKVKernel
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
elif dtype in ['fp8', 'bf8']:
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else:
assert False
return pipelines
gen = list()
api_pool = FmhaFwdSplitKVApiPool(mask_impl)
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
continue
k = Kernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# Flash attention integration
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16, bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
cond &= mode == 'batch'
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# aiter::mha_fwd_splikv C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim_list) -> List[FmhaFwdSplitKVCombineKernel]:
Pipeline = FmhaFwdSplitKVCombinePipeline
Kernel = FmhaFwdSplitKVCombineKernel
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]):
pipelines.append(Pipeline('unused', spad, dvpad, lse, squant))
elif dtype in ['fp8', 'bf8']:
# no need lse kernels
pipelines.append(Pipeline('unused', 'f', 'f', 'f', squant))
else:
assert False
return pipelines
gen = list()
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
if mode == "group":
if pipeline.F_spad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
k = Kernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline)
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# Aiter(mha_varlen_fwd) integration
if receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
if not cond:
continue
# aiter::mha_fwd_splikv C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
gen.append(k)
return gen
def write_single_kernel(kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None:
file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME
file_path.write_text(api_pool.api)
def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
filter_list = filter_list.split('@')
filter_list.extend([''] * (2 - len(filter_list)))
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
write_fwd_splitkv_api(api_pool, output_dir)
def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
filter_list = filter_list.split('@')
filter_list.extend([''] * (2 - len(filter_list)))
with file_path.open('a') as f:
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
_, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n")

View File

@@ -0,0 +1,591 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8" : 8,
"bf8" : 8
}
K0_MAX_SUBMAX_MAP = {
32 : 32,
64 : 64,
96 : 128,
128: 128,
256: 256
}
FMHA_FWD_PAGEDKV_PIPELINE_MAP = {
"qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS"
}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_logits},
{F_bias},
false,
{F_lse}, //lse
{F_pagedkv}, //pagedkv
{F_squant},
{F_occupancy},
{F_skip}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdPagedKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
#include <iostream>
template<>
float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp"
FMHA_FWD_API="""
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return fmha_fwd_pagedkv_<trait_>(s, a);
}}
"""
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
pagedkv : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
skip : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
@property
def scheck(self) -> str:
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0'
else: assert False
@property
def dcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
else: assert False
@property
def dvcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
else: assert False
@dataclass
class FmhaFwdPipeline:
tag : str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_pagedkv : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_skip : str # true/false
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_logits == 't' : n += '_logits'
else: n += '_nlogits'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else: n += '_nmask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
else: n += '_nmask'
if self.F_lse == 't' : n += '_lse'
else: n += '_nlse'
if self.F_skip == 't' : n += '_skip'
else: n += '_nskip'
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
if self.F_pagedkv == 't' : n += '_pagedkv'
else: n += '_npagedkv'
return n
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
@dataclass
class FmhaFwdTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass
class FmhaFwdKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0max = self.F_tile.F_bk0max,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag])
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
pagedkv=self.F_pipeline.F_pagedkv,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
}
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]):
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else:
assert False
return pipelines
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
# if pipeline.F_pagedkv == 'f':
# continue
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' :
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
continue
k = FmhaFwdKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'batch'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# aiter::mha_fwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import os.path as path
def update_file(file_path, content):
"""Update the file at file_path with the given content if it differs from the existing content.
It avoids unnecessary touching of the file which triggers rebuilds
"""
existing_content = ""
if path.exists(file_path):
with open(file_path, "r") as file:
existing_content = file.read()
if existing_content == content:
return
with open(file_path, "w") as file:
file.write(content)

View File

@@ -0,0 +1,616 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <ck_tile/core/numeric/bfloat16.hpp>
#include <ck_tile/core/numeric/half.hpp>
#include <ck_tile/core/numeric/math.hpp>
#include <ck_tile/core/utility/functional.hpp>
#include <ck_tile/host/arg_parser.hpp>
#include <ck_tile/host/device_memory.hpp>
#include <ck_tile/host/fill.hpp>
#include <ck_tile/host/check_err.hpp>
#include <ck_tile/host/host_tensor.hpp>
#include <ck_tile/host/reference/reference_batched_gemm.hpp>
#include <ck_tile/host/reference/reference_batched_masking.hpp>
#include <ck_tile/host/reference/reference_batched_softmax.hpp>
#include "fmha_fwd.hpp"
#include "fmha_fwd_v3.hpp"
#include "mask.hpp"
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("prec", "fp16", "data type. fp16/bf16")
.insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q")
.insert("h_k",
"-1",
"num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("s", "3328", "seqlen_q")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q & k")
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
.insert("iperm",
"0",
"permute input\n"
"if true, will be b*h*s*d, else b*s*h*d")
.insert("operm", "0", "permute output")
.insert("causal", "0", "0: no mask, 1: causal mask")
.insert("v", "1", "0:no verify, 1:verify")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "30", "number of iterations to benchmark the kernel")
// Optional effective seqlen override (exclude PAD) for batch mode
.insert("q_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.")
.insert("kv_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.");
bool result = arg_parser.parse(argc, argv);
return std::make_pair(result, arg_parser);
}
enum class TensorLayout
{
bhsd,
bshd,
};
std::ostream& operator<<(std::ostream& stream, TensorLayout layout)
{
switch(layout)
{
case TensorLayout::bhsd: return stream << "bhsd";
case TensorLayout::bshd: return stream << "bshd";
default: return stream << "unknown";
}
}
struct Problem
{
explicit Problem(const ck_tile::ArgParser& args)
{
data_type = args.get_str("prec") == "fp16"
? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16
: ck_tile::fmha_fwd_v3_args::data_type_enum::bf16;
batch = args.get_int("b");
seqlen_q = args.get_int("s");
seqlen_k = args.get_int("s_k");
if(seqlen_k < 0)
{
seqlen_k = seqlen_q;
}
nhead_q = args.get_int("h");
nhead_kv = args.get_int("h_k");
if(nhead_kv < 0)
{
nhead_kv = nhead_q;
}
hdim = args.get_int("d");
softmax_scale = args.get_float("scale_s");
if(softmax_scale == .0f)
softmax_scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim));
const auto is_causal = args.get_bool("causal");
if(is_causal)
{
mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k);
}
else
{
mask = mask_info::decode("0", seqlen_q, seqlen_k);
}
input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
q_eff_lens = args.get_int_vec("q_eff_lens");
kv_eff_lens = args.get_int_vec("kv_eff_lens");
}
std::vector<ck_tile::index_t> get_query_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_q, seqlen_q, hdim};
}
else
{
return {batch, seqlen_q, nhead_q, hdim};
}
}
std::vector<ck_tile::index_t> get_key_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_kv, seqlen_k, hdim};
}
else
{
return {batch, seqlen_k, nhead_kv, hdim};
}
}
std::vector<ck_tile::index_t> get_value_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_kv, seqlen_k, hdim};
}
else
{
return {batch, seqlen_k, nhead_kv, hdim};
}
}
std::vector<ck_tile::index_t> get_output_shape() const
{
if(output_layout == TensorLayout::bhsd)
{
return {batch, nhead_q, seqlen_q, hdim};
}
else
{
return {batch, seqlen_q, nhead_q, hdim};
}
}
ck_tile::fmha_fwd_v3_args::data_type_enum data_type;
ck_tile::index_t batch;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_kv;
ck_tile::index_t hdim;
float softmax_scale;
mask_info mask;
TensorLayout input_layout;
TensorLayout output_layout;
std::vector<int> q_eff_lens;
std::vector<int> kv_eff_lens;
};
struct RunConfig
{
explicit RunConfig(const ck_tile::ArgParser& args)
{
seed = args.get_uint32("seed");
if(*seed == 0)
{
seed.reset();
}
kernel_warmup = args.get_int("warmup");
kernel_repeat = args.get_int("repeat");
verify = args.get_bool("v");
}
std::optional<uint32_t> seed;
int kernel_warmup;
int kernel_repeat;
bool verify;
};
template <typename DataType>
auto generate_qkv(const Problem& problem,
[[maybe_unused]] std::optional<uint32_t> seed = std::nullopt)
-> std::tuple<ck_tile::HostTensor<DataType>,
ck_tile::HostTensor<DataType>,
ck_tile::HostTensor<DataType>>
{
ck_tile::HostTensor<DataType> q(problem.get_query_shape());
ck_tile::HostTensor<DataType> k(problem.get_key_shape());
ck_tile::HostTensor<DataType> v(problem.get_value_shape());
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(q);
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(k);
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(v);
return std::make_tuple(q, k, v);
}
namespace host {
template <typename AccDataType,
typename PDataType,
typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename QElementOp,
typename KElementOp,
typename VElementOp,
typename SAccElementOp>
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
const ck_tile::HostTensor<KDataType>& k_bshd,
const ck_tile::HostTensor<VDataType>& v_bshd,
const mask_info& mask,
ck_tile::HostTensor<ODataType>& o_bshd,
const QElementOp& q_element_op = {},
const KElementOp& k_element_op = {},
const VElementOp& v_element_op = {},
const SAccElementOp& s_acc_element_op = {})
{
const int batch_size = q_bshd.mDesc.get_lengths()[0];
const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
const int nhead_q = q_bshd.mDesc.get_lengths()[2];
const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
const int hdim_v = v_bshd.mDesc.get_lengths()[3];
const int nr = nhead_q / nhead_kv;
ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
// do computation for each batch
for(int b = 0; b < batch_size; ++b)
{
// copy per-batch data from input tensors
// clang-format off
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); });
k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); });
v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
// clang-format on
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
if(mask.type == mask_enum::no_mask)
{
ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
}
else if(mask.type == mask_enum::window_generic)
{
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, seqlen_q, seqlen_kv));
}
else
{
// if left window size is negative, means causal
// else means generic (for current batch)
if(mask.left < 0)
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
else
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
}
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
s_host_ref, p_host_ref, ck_tile::identity{});
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
// copy resulting per-batch data to the output tensor
o_host_ref.ForEach(
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
}
}
} // namespace host
template <typename DataType>
bool run_impl(const Problem& problem, const RunConfig& run_config)
{
auto [q, k, v] = generate_qkv<DataType>(problem, run_config.seed);
ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes());
/// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v
ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes());
q_buf.ToDevice(q.data());
k_buf.ToDevice(k.data());
v_buf.ToDevice(v.data());
// Ensure output buffer is zero-initialized so padded regions compare cleanly
o_buf.SetZero();
ck_tile::fmha_fwd_v3_args args{};
args.data_type = problem.data_type;
args.batch = problem.batch;
args.seqlen_q = problem.seqlen_q;
args.seqlen_k = problem.seqlen_k;
args.nhead_q = problem.nhead_q;
args.nhead_kv = problem.nhead_kv;
args.hdim_qk = problem.hdim;
args.hdim_v = problem.hdim;
args.softmax_scale = problem.softmax_scale;
args.window_size_left = problem.mask.left;
args.window_size_right = problem.mask.right;
args.mask_type = static_cast<ck_tile::index_t>(problem.mask.type);
// bshd: (batch, seqlen_q, nhead_q, hdim)
// bhsd: (batch, nhead_q, seqlen_q, hdim)
args.q_ptr = q_buf.GetDeviceBuffer();
args.stride_q =
problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
args.nhead_stride_q =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim;
args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim;
// bshd: (batch, seqlen_k, nhead_kv, hdim)
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
args.k_ptr = k_buf.GetDeviceBuffer();
args.stride_k =
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
args.nhead_stride_k =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim;
// bshd: (batch, seqlen_k, nhead_kv, hdim)
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
args.v_ptr = v_buf.GetDeviceBuffer();
args.stride_v =
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
args.nhead_stride_v =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim;
// bshd: (batch, seqlen_q, nhead_q, hdim)
// bhsd: (batch, nhead_q, seqlen_q, hdim)
args.o_ptr = o_buf.GetDeviceBuffer();
args.stride_o =
problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
args.nhead_stride_o = problem.output_layout == TensorLayout::bshd
? problem.hdim
: problem.seqlen_q * problem.hdim;
args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim;
// Optional cumulative seqlen overrides (exclude PAD)
const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1;
const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1;
auto make_effective_vec = [&](const std::vector<int>& opt_vec, ck_tile::index_t fallback) {
std::vector<ck_tile::index_t> eff;
if(!opt_vec.empty() && opt_vec[0] != -1)
{
eff.assign(opt_vec.begin(), opt_vec.end());
if(eff.size() < static_cast<size_t>(problem.batch))
{
eff.resize(problem.batch, eff.back());
}
}
else
{
eff.assign(problem.batch, fallback);
}
return eff;
};
const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q);
const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k);
// Calculate cumulative sums for kernel arguments if varlen is used
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
auto calculate_cumulative = [&](const std::vector<ck_tile::index_t>& per_batch_vec,
std::vector<ck_tile::index_t>& cum_vec) {
cum_vec.resize(per_batch_vec.size() + 1);
cum_vec[0] = 0;
for(std::size_t i = 0; i < per_batch_vec.size(); ++i)
cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
};
if(has_varlen_q)
{
calculate_cumulative(eff_q_vec, cuq_cum);
}
if(has_varlen_k)
{
calculate_cumulative(eff_kv_vec, cukv_cum);
}
ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0);
ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0);
cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr);
cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr);
args.cu_seqlen_q_ptr =
!cuq_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cuq_buf.GetDeviceBuffer())
: nullptr;
args.cu_seqlen_kv_ptr =
!cukv_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cukv_buf.GetDeviceBuffer())
: nullptr;
ck_tile::stream_config stream_config{nullptr,
true,
/*log_level=*/0,
run_config.kernel_warmup,
run_config.kernel_repeat};
auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config);
if(!result)
{
std::cerr << "faild to run fmha_fwd_v3()" << std::endl;
return false;
}
std::size_t flop = [&] {
if(problem.mask.type == mask_enum::no_mask)
{
return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
else
{
/// FIXME: Use a more accurate method; for now, were just dividing the flop by 2.
return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
}();
float tflops = static_cast<float>(flop) / 1.e9 / time;
std::cout << "[" << problem.data_type << "|";
if(problem.input_layout == problem.output_layout)
{
std::cout << problem.input_layout;
}
else
{
std::cout << problem.input_layout << "-" << problem.output_layout;
}
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
<< ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
<< ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed
<< ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops
<< " TFlops" << std::endl;
if(!run_config.verify)
{
return true;
}
// transpose tensor descriptors from bhsd to bshd if necessary
if(problem.input_layout != TensorLayout::bshd)
{
q = q.transpose({0, 2, 1, 3});
k = k.transpose({0, 2, 1, 3});
v = v.transpose({0, 2, 1, 3});
}
ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
if(problem.output_layout != TensorLayout::bshd)
{
o_ref = o_ref.transpose({0, 2, 1, 3});
}
// If variable lengths are provided, compute per-batch references
// with the effective lengths; else compute a single full reference.
if(has_varlen_q || has_varlen_k)
{
// Variable-length aware verification: zero-fill padded region and only compute valid part.
o_ref.SetZero();
for(int b = 0; b < problem.batch; ++b)
{
const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
continue;
// Slice current batch from inputs (bshd) and build single-batch tensors
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
// Copy effective region
q_b.ForEach([&](auto& self, auto idx) {
// idx: [0, s, h, d]
self(idx) = q(b, idx[1], idx[2], idx[3]);
});
k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
host::fmha_fwd<float, DataType>(q_b,
k_b,
v_b,
problem.mask,
o_b,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
// Scatter into o_ref's bshd descriptor memory
for(int s = 0; s < seqlen_q_eff; ++s)
{
for(int h = 0; h < problem.nhead_q; ++h)
{
for(int d = 0; d < problem.hdim; ++d)
{
o_ref(b, s, h, d) = o_b(0, s, h, d);
}
}
}
}
}
else
{
// No varlen override: compute the full reference once
host::fmha_fwd<float, DataType>(q,
k,
v,
problem.mask,
o_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
}
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
o_buf.FromDevice(o.data());
const auto [rtol, atol] = [&] {
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
return std::make_tuple(1e-3, 1e-3);
else
return std::make_tuple(1e-2, 1e-2);
}();
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
}
int main(int argc, char* argv[])
{
auto [parse_result, args] = parse_cmd_args(argc, argv);
if(!parse_result)
{
std::cerr << "failed to parse command line arguments" << std::endl;
}
Problem problem(args);
RunConfig run_config(args);
const auto run = [&] {
if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16)
{
return run_impl<ck_tile::fp16_t>(problem, run_config);
}
else
{
return run_impl<ck_tile::bf16_t>(problem, run_config);
}
};
return !run();
}

View File

@@ -0,0 +1,132 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import argparse
from enum import IntEnum
from pathlib import Path
import pkgutil
import sys
from typing import List, Optional
import codegen.ops
from codegen.cmake_config import *
class HandlerId(IntEnum):
LIST_BLOBS = 0
WRITE_BLOBS = 1
# inspect all modules under 'codegen.ops' and register API handlers
ops = []
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
unwanted_prefix = 'fmha_'
handlers = dict(
[(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
(op.list_blobs, op.write_blobs)) for op in ops]
)
assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
output_dir = Path(output_dir) / GEN_DIR
output_dir.mkdir(parents=True, exist_ok=True)
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl)
# list all the files that will be generated
def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
assert output_file is not None
file_path = Path(output_file)
# create an empty file / drop its contents if it exists
open(file_path, "w").close()
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, optdim_list, mask_impl)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
description="gen API for CK fmha kernel",
)
parser.add_argument(
"-d",
"--direction", # we keep 'direction' option for backward compatibility
"-a",
"--api",
default='fwd',
required=False,
help="supply API(s) to generate (default: fwd). separated by comma."
)
parser.add_argument(
"-o",
"--output_dir",
required=False,
help="write all the blobs into a directory"
)
parser.add_argument(
"-l",
"--list_blobs",
required=False,
help="list all the kernels to a file"
)
# TODO: if using filter, must apply same value to output_dir and list_blobs
parser.add_argument(
"-f",
"--filter",
default='',
required=False,
help="filter out kernels that need to generate, using fnmatch module"
)
parser.add_argument(
"-m",
"--mask",
default="simplified",
required=False,
help="mask implementation, simplified/generic"
)
parser.add_argument(
"-r",
"--receipt",
default=0,
required=False,
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
" 1: generate more instance to cover all hdim\n" + \
" 2: Only generate instance for Flash attention integration\n" + \
" 4: Only generate instance for PyTorch integration\n" + \
" 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \
" 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \
" 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \
" 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration"
)
parser.add_argument(
"--optdim",
default='-1',
required=False,
help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \
"eg. --optdim=32,64,128,256"
)
args = parser.parse_args()
api_list = args.direction.split(',')
filter_list = args.filter.split(',')
filter_list.extend([''] * (len(api_list) - len(filter_list)))
optdim_list = [int(hdim) for hdim in args.optdim.split(',')]
if args.list_blobs is not None:
list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)
else:
write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false, true>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false, false>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false, true>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false, false>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,167 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum class mask_enum
{
no_mask = 0,
mask_top_left,
mask_bottom_right,
window_generic,
};
struct mask_info
{
mask_enum type;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t y, x;
ck_tile::index_t left, right; // FA style SWA left/right
void serialize(std::ostream& os) const
{
if(type == mask_enum::no_mask)
os << "n";
else if(type == mask_enum::mask_top_left)
os << "t(" << left << ":" << right << ")";
else if(type == mask_enum::mask_bottom_right)
os << "b(" << left << ":" << right << ")";
else
{
os << "g(" << y << ":" << x << ")";
}
}
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
{
ck_tile::index_t x_total = seqlen_k;
ck_tile::index_t y_total = seqlen_q;
mask_info tmp;
tmp.seqlen_q = seqlen_q;
tmp.seqlen_k = seqlen_k;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string t = str.substr(0, found_0);
std::string v = str.substr(found_0 + 1);
if(t == "xt" || t == "xb")
{
// xformer style sliding window attn from top-left
ck_tile::index_t window_size = std::stoi(v);
ck_tile::index_t left_size = -1;
ck_tile::index_t right_size = 0;
if(window_size > 0)
{
left_size = window_size / 2;
right_size = window_size - 1 - left_size;
}
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, t == "xt");
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = left_size;
tmp.right = right_size;
}
else if(t == "t" || t == "b" || t == "g")
{
auto found_1 = v.find(",");
if(found_1 == std::string::npos)
{
throw std::invalid_argument("invalid mask value: " + str);
}
ck_tile::index_t v0 = std::stoi(v.substr(0, found_1));
ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1));
if(t == "t")
{
tmp.type = mask_enum::mask_top_left;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, true);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "b")
{
tmp.type = mask_enum::mask_bottom_right;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, false);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "g")
{
tmp.type = mask_enum::window_generic;
tmp.y = v0;
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
}
}
else
{
throw std::invalid_argument("invalid mask value: " + str);
}
}
else if(str == "0")
{
tmp.type = mask_enum::no_mask;
}
else if(str == "1" || str == "t")
{
tmp.type = mask_enum::mask_top_left;
tmp.y = seqlen_q;
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
}
else if(str == "2" || str == "b")
{
tmp.type = mask_enum::mask_bottom_right;
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
}
else
{
throw std::invalid_argument("invalid mask value: " + str);
}
return tmp;
}
ck_tile::index_t get_unmaskarea() const
{
if(type == mask_enum::no_mask)
return seqlen_q * seqlen_k;
ck_tile::index_t area = 0;
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
{
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
if(x_end > x_start)
{
area += (x_end - x_start);
}
}
return area;
}
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
{
mi.serialize(os);
return os;
}
};

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

View File

@@ -0,0 +1,84 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>
// keep sync with RotaryEmbeddingEnum
enum class rope_enum
{
none = 0,
interleaved = 1,
half_rotated = 2,
};
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen,
ck_tile::index_t rotary_dim,
std::optional<unsigned> seed = std::nullopt)
{
// return dummy tensors if we won't apply RoPE at all
if(rotary_dim <= 0)
{
ck_tile::HostTensor<DataType> dummy({1, 1});
return std::make_tuple(dummy, dummy);
}
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
const ck_tile::index_t num_rows = seqlen * 2;
const ck_tile::index_t num_cols = rotary_dim / 2;
using std::begin, std::end;
ck_tile::HostTensor<float> angle({num_rows, num_cols});
std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });
ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::cos(origin_value));
});
ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::sin(origin_value));
});
return std::make_tuple(cos, sin);
}
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
const ck_tile::HostTensor<DataType>& sin,
ck_tile::index_t seqlen_offset,
ck_tile::index_t seqlen)
{
assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));
const ck_tile::index_t num_rows = seqlen;
const ck_tile::index_t num_cols = cos.get_length(1);
ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });
ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });
return std::make_tuple(cos_pt, sin_pt);
}

View File

@@ -0,0 +1,20 @@
#!/bin/sh
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
VALID=0
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 32 64 128 ; do
nhead=$((2048 / $hdim)) # follow fav2 setup
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
done
done
done

View File

@@ -0,0 +1,53 @@
#!/bin/sh
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
VALID=0
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 64 128 256 ; do
nhead=$((2048 / $hdim)) # follow fav2 setup
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
done
done
done
#Padding Benchmarks: batch mode (baseline vs low/med/high pad)
prec="fp16"
base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"
# baseline (no pad)
$EXE $base_batch_args
# low pad (≈9095% effective)
$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
# medium pad (≈6075% effective)
$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
# high pad (≈3040% effective)
$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
# Padding Benchmarks: group mode (baseline vs low/med/high physical pad)
seqlens_q="1024,768,512,256"
seqlens_k="1024,768,512,256"
base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"
# baseline (no physical pad)
$EXE $base_group_args
# low physical pad
$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
# medium physical pad
$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384
# high physical pad
$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512

View File

@@ -0,0 +1,42 @@
#!/bin/sh
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_fwd_v3 -type f | head -n 1)"
VALID=0
for causal in 0 1 ; do
for prec in "fp16" "bf16" ; do
for hdim in 128 ; do
for perm in 0 ; do
$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
done
done
done
done
# Padding benchmark comparisons for v3 (batch mode only)
# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ====
prec="fp16"
base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID"
# baseline (no pad)
$EXE $base_v3_args
# low pad (≈9095% effective)
$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
# medium pad (≈6075% effective)
$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
# high pad (≈3040% effective)
$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320

View File

@@ -0,0 +1,48 @@
#!/bin/bash
#
# in order to run this script you'd first need to build the tile_example_fmha_fwd and tile_eaxmple_fmha_bwd executables in ../build/bin/
#
# run the script as "./run_full_test.sh <tag for your test environment> <branch name> <host name> <gpu_arch>
# input arguments:
# environment tag : a string describing the specifics of your test environment
# branch name : name of the branch in git repo (git status | grep -e 'On branch')
# host name : $hostname
# gpu architecture: e.g., gfx90a, or gfx942, etc.
set -euo pipefail
#get the command line arguments:
export env_type=$1
echo 'Environment type: ' $env_type
export branch=$2
echo 'Branch name: ' $branch
export host_name=$3
echo 'Host name: ' $host_name
export GPU_arch=$4
echo 'GPU_arch: ' $GPU_arch
function print_log_header(){
rm -f $1;
echo 'On branch ' $3 &> $1;
echo 'Node name: ' $4 >> $1;
#get GPU_arch and number of compute units from rocminfo
echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
rocminfo | grep "Compute Unit:" >> $1;
hipcc --version | grep -e 'HIP version' >> $1;
echo 'Environment type: ' $2 >> $1;
/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
}
#run verification tests
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
#run performance benchmarks
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
print_log_header $fmha_fwd_log $env_type $branch $host_name
time example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log
export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log"
print_log_header $fmha_bwd_log $env_type $branch $host_name
time example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log

View File

@@ -0,0 +1,90 @@
#!/bin/bash
# TODO: run this script from CK root or build directory
set -euo pipefail
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
EXE_NAME=tile_example_fmha_bwd
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
KNAME=1
GPU_arch=${GPU_arch:-""}
if [ -z "$GPU_arch" ] ; then
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
fi
export CK_WARMUP=0
export CK_REPEAT=1
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_bwd_fails_$GPU_arch.txt"}
rm -f $CURR_FAILS_FILE
touch $CURR_FAILS_FILE
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_bwd_known_fails_$GPU_arch.txt"}
COMMON_ARGS='-v=1'
run_exe() {
set +ex
$EXE $@
local ret=$?
if [ $ret -ne 0 ] ; then
echo "$EXE_NAME $*" >> $CURR_FAILS_FILE
fi
set -ex
}
test_h_s_mask() {
run_exe -b=1 -h=4 -h_k=2 -s=259 $@
run_exe -b=2 -h=2 -s=516 -s_k=253 $@
run_exe -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 $@
run_exe -b=1 -h=2 -s=900 -s_k=258 -mask=2 $@
run_exe -b=2 -h=1 -s=987 -s_k=219 -mask=t:128,30 $@
run_exe -b=2 -h=3 -h_k=1 -s=244 -s_k=499 -mask=b:4,35 $@
}
set -x
# main tests
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 32 64 128 256 ; do
for mode in 0 1 ; do
for bias in "n" "a" ; do
for dbias in 0 ; do
for p_drop in 0.0 0.2 ; do
for deterministic in 0 ; do
test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done
done
done
done
done
done
done
done
# additional cases
for hdim in 40 48 72 96 ; do
test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
done
set +x
new_fails_count=0
known_fails_count=0
if [ -f $KNOWN_FAILS_FILE ] ; then
echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
while IFS= read -r line; do
if grep -Fxq "$line" $KNOWN_FAILS_FILE; then
echo "Known fail: $line"
known_fails_count=$(($known_fails_count + 1))
else
echo "New fail: $line"
new_fails_count=$(($new_fails_count + 1))
fi
done < $CURR_FAILS_FILE
else
new_fails_count=$(wc -l < $CURR_FAILS_FILE)
echo "No known fails file, all fails ($new_fails_count) are new:"
cat $CURR_FAILS_FILE
fi
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
exit $(($new_fails_count != 0))

View File

@@ -0,0 +1,281 @@
#!/bin/bash
# TODO: run this script from CK root or build directory
set -euo pipefail
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
EXE_NAME=tile_example_fmha_fwd
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
KNAME=1
GPU_arch=$GPU_arch
if [ -z "$GPU_arch" ] ; then
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
fi
export CK_WARMUP=0
export CK_REPEAT=1
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_fwd_fails_$GPU_arch.txt"}
rm -f $CURR_FAILS_FILE
touch $CURR_FAILS_FILE
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_fwd_known_fails_$GPU_arch.txt"}
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4
TEST_SPLITKV=0
TEST_APPENDKV=0
# options:
# -s: run splitkv tests
# -a: run appendkv tests
while getopts ":sa" opt; do
case "${opt}" in
s)
TEST_SPLITKV=1
;;
a)
TEST_APPENDKV=1
;;
*)
;;
esac
done
run_exe() {
set +ex
$EXE $@
local ret=$?
if [ $ret -ne 0 ] ; then
echo "$EXE_NAME $*" >> $CURR_FAILS_FILE
fi
set -ex
}
run_fp16_bf16_tests() {
local NUM_SPLITS="1"
local PAGE_BLOCK_SIZE="0"
local CACHE_BATCH_IDX="0"
if [ $TEST_SPLITKV -eq 1 ] ; then
NUM_SPLITS="$NUM_SPLITS 2 3"
PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
fi
for prec in "fp16" "bf16" ; do
for mode in 1 0 ; do
for perm in 0 1 ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2 ; do
for num_splits in $NUM_SPLITS ; do
for page_block_size in $PAGE_BLOCK_SIZE ; do
for cache_batch_idx in $CACHE_BATCH_IDX ; do
# run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done ; done ; done
}
run_fp8_tests() {
for perm in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done
}
run_fp8bf16_tests() {
for perm in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done
}
run_fp8fp32_tests() {
for perm in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done
}
run_fp16_appendkv_tests() {
for s in $(seq 63 1 65) ; do
for s_k in 65 129 ; do
for s_knew in 0 64 $s_k ; do
for hdim in 32 64 128 256 ; do
for ri in 0 1 ; do
for rdim in 0 16 32 $hdim ; do
for page_block_size in 0 128 ; do
for cache_batch_idx in 0 1 ; do
run_exe -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done
}
run_padding_smoke_tests() {
# Padding-only smoke tests for batch/group mode using COMMON_ARGS
local prec="fp16"
# Batch mode: padding via effective lengths (exclude PAD)
# Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches
local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS"
# low pad (≈9095% effective)
$EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
# medium pad (≈6075% effective)
$EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
# high pad (≈3040% effective)
$EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
# Group mode: padding via physical stride along seqlen
local seqlens_q="1024,768,512,256"
local seqlens_k="1024,768,512,256"
local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS"
# low physical pad
$EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
# medium physical pad
$EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384
# high physical pad
$EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512
}
run_padding_basic_boundary_tests() {
# Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh)
local prec
local perm
# Group mode: Q&K padded with per-batch different strides
for prec in fp16 bf16 ; do
for perm in 0 1 ; do
$EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \
-s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \
-bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
done
done
# slightly larger, uneven padding strides
for prec in fp16 bf16 ; do
for perm in 0 1 ; do
$EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \
-s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
done
done
# only K padded; Q unpadded
for prec in fp16 bf16 ; do
for perm in 0 1 ; do
$EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \
-s=55 -s_k=256 -s_kpad=272,260 \
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
done
done
# use cu_seqlen overrides to skip tail PAD
for prec in fp16 bf16 ; do
for perm in 0 1 ; do
$EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \
-q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \
-q_eff_lens=55,60 -kv_eff_lens=200,256 \
-bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
done
done
# no padding (equal), mixed Q/KV, all len=1
for prec in fp16 bf16 ; do
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
-q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
-q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
-q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
done
# highly variable logical lengths
for prec in fp16 bf16 ; do
$EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \
-s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
done
# GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16
for prec in fp16 bf16 ; do
$EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \
-s=256,129 -s_k=256,129 -s_kpad=256 \
-bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \
-kname=$KNAME $COMMON_ARGS
done
}
set -x
run_fp16_bf16_tests
run_padding_smoke_tests
run_padding_basic_boundary_tests
run_fp8_tests
run_fp8bf16_tests
run_fp8fp32_tests
if [ $TEST_APPENDKV -eq 1 ] ; then
run_fp16_appendkv_tests
fi
set +x
new_fails_count=0
known_fails_count=0
if [ -f $KNOWN_FAILS_FILE ] ; then
echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
while IFS= read -r line; do
if grep -Fxq "$line" $KNOWN_FAILS_FILE; then
echo "Known fail: $line"
known_fails_count=$(($known_fails_count + 1))
else
echo "New fail: $line"
new_fails_count=$(($new_fails_count + 1))
fi
done < $CURR_FAILS_FILE
else
new_fails_count=$(wc -l < $CURR_FAILS_FILE)
echo "No known fails file, all fails ($new_fails_count) are new:"
cat $CURR_FAILS_FILE
fi
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
exit $(($new_fails_count != 0))

View File

@@ -0,0 +1,60 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
#include "mask.hpp"
namespace ck_tile {
std::ostream& operator<<(std::ostream& stream, const unified_attention_args::data_type_enum& data_type)
{
switch(data_type)
{
case unified_attention_args::data_type_enum::fp16: return stream << "fp16";
case unified_attention_args::data_type_enum::bf16: return stream << "bf16";
default: return stream << "unknown";
}
}
std::pair<bool, float> unified_attention(const unified_attention_args& args, const stream_config& config)
{
if(args.data_type == unified_attention_args::data_type_enum::fp16)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false>;
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
}
else
{
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true>;
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
}
}
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false>;
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
}
else
{
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true>;
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
}
}
return std::make_pair(false, -1.f);
}
} // namespace ck_tile

View File

@@ -0,0 +1,74 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <utility>
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/host/stream_config.hpp"
namespace ck_tile {
struct unified_attention_args
{
enum class data_type_enum
{
fp16,
bf16
};
data_type_enum data_type;
// bool is_varlen;
index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and
// window_size_right == 0).
index_t num_blks;
index_t num_head_q;
index_t num_queries_per_kv;
// TODO window
float scale_s;
float scale;
float scale_k;
float scale_v;
float scale_out;
index_t total_num_q_blocks;
const void* q_ptr;
index_t query_stride_0;
index_t query_stride_1;
const void* k_ptr; // [num_blks, blk_size, num_kv_heads, head_size]
index_t stride_k_cache_0;
index_t stride_k_cache_1;
index_t stride_k_cache_2;
index_t stride_k_cache_3;
const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size]
index_t stride_v_cache_0;
index_t stride_v_cache_1;
index_t stride_v_cache_2;
index_t stride_v_cache_3;
void* o_ptr;
index_t output_stride_0;
index_t output_stride_1;
const int32_t* block_tables_ptr;
const int32_t* seq_lens_ptr; // seq len in each batch
const int32_t* query_start_len_ptr; // [num_seqs+1]
index_t num_seqs; // number of batches for q
};
std::ostream& operator<<(std::ostream& stream, const unified_attention_args::data_type_enum& data_type);
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
std::pair<bool, float> unified_attention(const unified_attention_args& args, const stream_config& config);
} // namespace ck_tile

View File

@@ -0,0 +1,158 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <utility>
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp"
#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp"
#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp"
#include "unified_attention.hpp"
#include "mask.hpp"
#define INST_unified_attention_DISPATCH(kernel_traits) \
template <> \
std::pair<bool, float> unified_attention_kernel_dispatch<kernel_traits>( \
const unified_attention_args& args, const stream_config& config) \
{ \
return std::make_pair(true, \
unified_attention_kernel_launch<kernel_traits::kernel>(args, config)); \
}
namespace ck_tile {
template <unified_attention_args::data_type_enum DataType>
struct unified_attention_problem_traits;
template <>
struct unified_attention_problem_traits<unified_attention_args::data_type_enum::fp16>
{
using qkvp_dtype = ck_tile::half_t;
using acc_dtype = float;
using o_dtype = ck_tile::half_t;
using lse_dtype = float;
};
template <>
struct unified_attention_problem_traits<unified_attention_args::data_type_enum::bf16>
{
using qkvp_dtype = ck_tile::bf16_t;
using acc_dtype = float;
using o_dtype = ck_tile::bf16_t;
using lse_dtype = float;
};
template <unified_attention_args::data_type_enum DataType, bool IsMasking>
struct unified_attention_kernel_traits
{
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
// BLOCK_Q BLOCK_SIZE HEAD_SIZE N1 K1
using unified_attention_block_tile = sequence<128, 128, 128>;
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
using unified_attention_block_warps = sequence<8, 1, 1>;
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
unified_attention_block_warps,
funified_attention_warp_gemm_shape,
funified_attention_block_warps,
funified_attention_warp_gemm_shape,
true // IsVLayoutRowMajor
>;
using unified_attention_traits = TileUnifiedAttentionTraits<true, // kPadSeqLenQ_
false, // kPadHeadDimQ
false, // kStoreLSE_
-1 // kBlockPerCu
>;
using funified_attention_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
using funified_attention_pipeline_problem =
UnifiedAttentionPipelineProblem<typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::lse_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::o_dtype,
funified_attention_shape,
funified_attention_mask,
funified_attention_traits>;
using funified_attention_pipeline = BlockFunified_attentionFwdV3Pipeline<funified_attention_pipeline_problem>;
using epilogue = Default2DEpilogue<
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::o_dtype,
true, // kPadM
true, // kPadM
true // UseRawStore
>>;
using kernel = UnifiedAttentionKernel<funified_attention_pipeline, epilogue>;
};
template <typename Kernel>
float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config)
{
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.o_ptr,
args.num_blks,
args.num_head_q,
args.num_queries_per_kv,
args.scale_s,
args.scale,
args.scale_k,
args.scale_v,
args.scale_out,
args.total_num_q_blocks,
args.query_stride_0,
args.query_stride_1,
args.stride_k_cache_0,
args.stride_k_cache_1,
args.stride_k_cache_2,
args.stride_k_cache_3,
args.stride_v_cache_0,
args.stride_v_cache_1,
args.stride_v_cache_2,
args.stride_v_cache_3,
args.output_stride_0,
args.output_stride_1,
args.block_tables_ptr,
args.seq_lens_ptr,
args.query_start_len_ptr,
args.num_seqs
);
dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, args.total_num_q_blocks);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
template <typename KernelTraits>
std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention_args& args,
const stream_config& config);
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,244 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cstdint>
#include <functional>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ck_tile/core/container/span.hpp"
enum class mode_enum
{
batch = 0,
group
};
std::ostream& operator<<(std::ostream& stream, mode_enum mode)
{
return stream << (mode == mode_enum::batch ? "batch" : "group");
}
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
using size_type = typename std::vector<T>::size_type;
os << "[";
for(size_type idx = 0; idx < v.size(); ++idx)
{
if(0 < idx)
{
os << ", ";
}
os << v[idx];
}
return os << "]";
}
std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
{
std::vector<int32_t> seqstarts = {0};
for(int32_t seqlen : seqlens)
{
seqstarts.push_back(seqstarts.back() + seqlen);
}
assert(seqstarts.size() == seqlens.size() + 1);
return seqstarts;
}
template <typename RandomEngine>
std::vector<int32_t> generate_seqlens(mode_enum mode,
unsigned count,
int32_t seqlen_avg,
int32_t seqlen_min, // if not negative, clamp min
int32_t seqlen_max, // if not negative, clamp max
RandomEngine& random_engine)
{
assert(0 < count);
seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
assert(seqlen_min <= seqlen_max);
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
if(mode == mode_enum::group && 1 < count)
{
using size_type = std::vector<int32_t>::size_type;
std::uniform_int_distribution<size_type> idx_dist(0, count - 1);
auto next_idx = std::bind(idx_dist, std::ref(random_engine));
std::uniform_int_distribution<size_type> step_dist(1, count - 1);
auto next_step = std::bind(step_dist, std::ref(random_engine));
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
{
const size_type to_decrease = next_idx();
// make sure each elements of seqlens is in range [seqlen_min, seqlen_max]
if(seqlens[to_decrease] == seqlen_min)
{
continue;
}
const size_type to_increase = (to_decrease + next_step()) % count;
if(seqlens[to_increase] >= seqlen_max)
{
continue;
}
--seqlens[to_decrease];
++seqlens[to_increase];
}
}
return seqlens;
}
// return random integer generated uniformly in range [low, high]
template <typename Int = int, typename RandomEngine>
auto randint(Int low,
Int high,
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>, Int>
{
std::uniform_int_distribution<Int> dist(low, high);
return dist(random_engine);
}
// return random integers generated uniformly in range [low, high]
template <typename Int, typename ForwardIterator, typename RandomEngine>
auto randints(ForwardIterator first,
ForwardIterator last,
Int low,
Int high,
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>>
{
std::uniform_int_distribution<Int> dist(low, high);
std::generate(first, last, [&] { return dist(random_engine); });
}
/*
* generate missing values in *_val randomly when the number of values is smaller than batch
* example (assume batch=3)
* q_val=1,2,3 k_val=4,5,6 -> OK
* q_val=1,2,3 -> OK, k same as q
* q_val=1,2 -> OK, q will rand remaining 1 element, k same as q
* q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element
* q_val=1,2,3,4 -> OK, but ignore exceed one
*
* q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
* q_val=1,2 k_val=4 -> not OK, k must have same splits with q
*/
template <typename RandomEngine>
std::tuple<std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>>
generate_missing_seqlens(mode_enum mode,
ck_tile::index_t batch,
const std::vector<ck_tile::index_t>& q_val,
const std::vector<ck_tile::index_t>& k_val,
const std::vector<ck_tile::index_t>& k_pad_val,
ck_tile::index_t seqlen_k_min,
bool need_append_kvcache,
RandomEngine& random_engine)
{
if(mode == mode_enum::batch)
{
ck_tile::index_t q = q_val[0];
ck_tile::index_t k = k_val[0];
auto s_q = std::vector<ck_tile::index_t>(batch, q);
auto s_k = [&] {
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
if(1 < batch && need_append_kvcache)
{
// to keep the original s_k value, we always use seqlen_k_max in first batch
randints(std::next(seqlen_ks.begin()),
seqlen_ks.end(),
seqlen_k_min,
seqlen_k_max,
random_engine);
return seqlen_ks;
}
return seqlen_ks;
}();
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
// s_k should be greater than or equal to seqlen_k_min if provided
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
return std::make_tuple(s_q, s_k, s_kpad);
}
else
{
std::vector<ck_tile::index_t> s_q;
std::vector<ck_tile::index_t> s_k;
std::vector<ck_tile::index_t> s_kpad;
ck_tile::index_t idx = 0;
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
{
ck_tile::index_t q = q_val[idx];
ck_tile::index_t k =
k_val[std::min(idx, static_cast<ck_tile::index_t>(k_val.size()) - 1)];
ck_tile::index_t kp =
k_pad_val.empty()
? -1
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
s_q.push_back(q);
s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp);
// s_k should be greater than or equal to seqlen_k_min
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
}
if(idx < batch)
{
auto rem_q =
generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine);
auto rem_k = generate_seqlens(
mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine);
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
}
return std::make_tuple(s_q, s_k, s_kpad);
}
}
template <typename RandomAccessIterator, typename Int, typename RandomEngine>
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
RandomAccessIterator last,
Int value,
RandomEngine& random_engine)
{
std::iota(first, last, value);
std::shuffle(first, last, random_engine);
}

View File

@@ -15,15 +15,16 @@
namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdV3Kernel
template <typename UnifiedAttentionPipeline_, typename EpiloguePipeline_>
struct UnifiedAttentionKernel
{
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using UnifiedAttentionPipeline = ck_tile::remove_cvref_t<UnifiedAttentionPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static constexpr ck_tile::index_t kBlockSize = UnifiedAttentionPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = UnifiedAttentionPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
<<<<<<< HEAD
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
@@ -32,23 +33,28 @@ struct FmhaFwdV3Kernel
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
=======
using QDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::VDataType>;
using ODataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::SaccDataType>;
using FmhaMask = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::FmhaMask>;
>>>>>>> 853fa21566f6a1fe4237289c61db772b1bbfeb3f
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kPadSeqLenQ = UnifiedAttentionPipeline::kPadSeqLenQ;
static constexpr bool kPadHeadDim = UnifiedAttentionPipeline::kPadHeadDim;
// TODO add yjese
static constexpr index_t HEAD_SIZE = FmhaPipeline::HEAD_SIZE;
static constexpr index_t HEAD_SIZE_PADDED = FmhaPipeline::HEAD_SIZE_PADDED;
static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE;
static constexpr index_t HEAD_SIZE_PADDED = UnifiedAttentionPipeline::HEAD_SIZE_PADDED;
// BLOCK_Q = BLOCK_M // num_queries_per_kv
// BLOCK_Q is the block size for q seqlen
static constexpr index_t BLOCK_Q = FmhaPipeline::BLOCK_Q;
// static constexpr index_t BLOCK_M = FmhaPipeline::BLOCK_M;
static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q;
// static constexpr index_t BLOCK_M = UnifiedAttentionPipeline::BLOCK_M;
// BLOCK size for K seqlen
static constexpr index_t BLOCK_SIZE = FmhaPipeline::BLOCK_SIZE;
static constexpr index_t BLOCK_SIZE = UnifiedAttentionPipeline::BLOCK_SIZE;
// kargs use aggregate initializer, so no constructor will provided
@@ -255,7 +261,7 @@ struct FmhaFwdV3Kernel
ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks;
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
// FmhaPipeline::kN1);
// UnifiedAttentionPipeline::kN1);
const index_t i_tile_m = pid % total_num_q_blocks; // Query block index
const index_t i_tile_n = pid / total_num_q_blocks; // Head index
@@ -263,9 +269,11 @@ struct FmhaFwdV3Kernel
return ck_tile::make_tuple(i_tile_m, i_tile_n);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
return ck_tile::max(UnifiedAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -352,7 +360,7 @@ struct FmhaFwdV3Kernel
q_ptr,
make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE),
make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<UnifiedAttentionPipeline::kAlignmentQ>{},
number<1>{});
const auto q_dram_pad = pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
@@ -390,7 +398,7 @@ struct FmhaFwdV3Kernel
k_ptr,
make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE),
make_tuple(kargs.stride_k_cache_0, kargs.stride_k_cache_1, kargs.stride_k_cache_3),
number<FmhaPipeline::kAlignmentK>{},
number<UnifiedAttentionPipeline::kAlignmentK>{},
number<1>{});
const auto k_dram_pad = pad_tensor_view(
@@ -423,7 +431,7 @@ struct FmhaFwdV3Kernel
v_ptr,
make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE),
make_tuple(kargs.stride_v_cache_0, kargs.stride_v_cache_1, kargs.stride_v_cache_3),
number<FmhaPipeline::kAlignmentV>{},
number<UnifiedAttentionPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_pad = pad_tensor_view(
@@ -462,7 +470,7 @@ struct FmhaFwdV3Kernel
}();
auto o_acc_tile = [&]() {
return FmhaPipeline{}(q_dram_window,
return UnifiedAttentionPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
block_tables_ptr,
@@ -479,7 +487,7 @@ struct FmhaFwdV3Kernel
o_ptr,
make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE),
make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1),
number<FmhaPipeline::kAlignmentO>{},
number<UnifiedAttentionPipeline::kAlignmentO>{},
number<1>{});
const auto o_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED

View File

@@ -0,0 +1,68 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <index_t Headdim>
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length()
{
if constexpr(Headdim == 48)
return 48;
else if constexpr(Headdim == 96)
return 128;
else if constexpr(Headdim == 160)
return 256;
else if constexpr(Headdim == 192)
return 192;
else if constexpr(is_power_of_two_integer(Headdim))
return Headdim;
else
static_assert(Headdim == 0,
"only Headdim of 48, 96, 160, 192 and power-of-two is supported");
};
template <typename BlockTile_, // sequence<...
typename Gemm0BlockWarps_,
typename Gemm0WarpTile_,
typename Gemm1BlockWarps_,
typename Gemm1WarpTile_,
bool IsVLayoutRowMajor_>
struct TileUnifiedAttentionShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
using Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>;
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
static constexpr index_t NumGemm0Warps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t NumGemm1Warps =
reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
static constexpr index_t BLOCK_Q = BlockTile::at(number<0>{}); // tile size along q seqlen
// static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * num_queries_per_kv (q_head//kv_head)
static constexpr index_t BLOCK_SIZE = BlockTile::at(number<1>{}); // BLOCK size for K seqlen
static constexpr index_t HEAD_SIZE = BlockTile::at(number<2>{}); // BLOCK size for K seqlen
// static constexpr index_t kQKHeaddim =
// BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
// // once (or repeately load Q as a whole tile)
// static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
static constexpr index_t HEAD_SIZE_PADDED = ceil_to_qualified_tile_length<HEAD_SIZE>();
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
using VLayout = std::conditional_t<IsVLayoutRowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
};
} // namespace ck_tile

View File

@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace ck_tile {
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDim /* paddding for hdim_v */,
bool kStoreLSE_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileUnifiedAttentionTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDim = kPadHeadDim;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
}

View File

@@ -4,8 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp"
#include "ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp"
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#define ENABLE_ASM_MARKER 1
@@ -270,21 +269,15 @@ struct UnifiedAttentionPipeline
static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize;
static constexpr ck_tile::index_t kM0 = UnifiedAttentionShape::kM0;
static constexpr ck_tile::index_t kN0 = UnifiedAttentionShape::kN0;
static constexpr ck_tile::index_t kK0 = UnifiedAttentionShape::kK0;
static constexpr ck_tile::index_t kN1 = UnifiedAttentionShape::kN1;
static constexpr ck_tile::index_t kK1 = UnifiedAttentionShape::kK1;
static constexpr ck_tile::index_t kQKHeaddim = UnifiedAttentionShape::kQKHeaddim;
static constexpr ck_tile::index_t kSubQKHeaddim = UnifiedAttentionShape::kSubQKHeaddim;
static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q;
static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE;
static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE;
static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static_assert(HEAD_SIZE_PADDED <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kPadHeadDim = Problem::kPadHeadDim;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
@@ -308,12 +301,12 @@ struct UnifiedAttentionPipeline
}
}();
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize(index_t num_queries_per_kv)
{
// create another LDS buffer for p
return ck_tile::max(kM0 * kN1 * sizeof(PDataType),
return ck_tile::max(BLOCK_Q * num_queries_per_kv * HEAD_SIZE_PADDED * sizeof(PDataType),
Policy::template GetSmemSize<Problem>() +
kM0 * kN0 * sizeof(PDataType));
BLOCK_Q * num_queries_per_kv * BLOCK_SIZE * sizeof(PDataType));
}
// for debug only
@@ -391,6 +384,7 @@ struct UnifiedAttentionPipeline
[[maybe_unused]] const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
[[maybe_unused]] const VElementFunction& v_element_func,
index_t num_queries_per_kv,
const void* block_tables_ptr,
index_t block_table_offset,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
@@ -411,39 +405,39 @@ struct UnifiedAttentionPipeline
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
static_assert(BLOCK_Q == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
static_assert(sizeof(SaccDataType) * BLOCK_Q * BLOCK_SIZE <= GetSmemSize(num_queries_per_kv));
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
MakeSimpleLdsDesc<kM0, kN0>());
MakeSimpleLdsDesc<BLOCK_Q, BLOCK_SIZE>());
[[maybe_unused]] auto s_lds_window =
make_tile_window(s_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
make_tile_window(s_lds, make_tuple(number<BLOCK_Q>{}, number<BLOCK_SIZE>{}), {0, 0});
auto p_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr) +
Policy::template GetSmemSize<Problem>()),
MakeSimpleLdsDesc<kM0, kN0>());
MakeSimpleLdsDesc<BLOCK_Q, BLOCK_SIZE>());
[[maybe_unused]] auto p_lds_window =
make_tile_window(p_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
make_tile_window(p_lds, make_tuple(number<BLOCK_Q>{}, number<BLOCK_SIZE>{}), {0, 0});
auto o_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
MakeSimpleLdsDesc<kM0, kN1>());
MakeSimpleLdsDesc<BLOCK_Q, BLOCK_SIZE>());
[[maybe_unused]] auto o_lds_window =
make_tile_window(o_lds, make_tuple(number<kM0>{}, number<kN1>{}), {0, 0});
make_tile_window(o_lds, make_tuple(number<BLOCK_Q>{}, number<BLOCK_SIZE>{}), {0, 0});
auto m_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SMPLComputeDataType*>(static_cast<char*>(smem_ptr) +
Policy::template GetSmemSize<Problem>()),
MakeSimpleLdsDesc1D<kM0>());
MakeSimpleLdsDesc1D<BLOCK_Q>());
[[maybe_unused]] auto m_lds_window =
make_tile_window(m_lds, make_tuple(number<kM0>{}), {0});
make_tile_window(m_lds, make_tuple(number<BLOCK_Q>{}), {0});
const index_t warp_group_id = get_warp_id() / 4;
@@ -550,9 +544,9 @@ struct UnifiedAttentionPipeline
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<BLOCK_Q>{}, number<BLOCK_SIZE>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, BLOCK_SIZE);
index_t kv_token_start = seqlen_k_start;
// check early exit if no work to do
@@ -596,11 +590,11 @@ struct UnifiedAttentionPipeline
v_dram_window.init_raw();
// prefetch K tile
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
constexpr index_t k0_loops = 1;
constexpr index_t k1_loops = 1;
static_assert(1 == k0_loops);
static_assert(1 == k1_loops);
static_assert(kN0 == kK1);
// static_assert(BLOCK_SIZE == HEAD_SIZE_PADDED);
constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
static_assert(NumWarpGroups == 2);
@@ -832,21 +826,21 @@ struct UnifiedAttentionPipeline
clear_tile(sp(sp_reg_idx).sp_compute); // initialize C
gemm_0(sp(sp_reg_idx).sp_compute,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_Q, k0_loops * HEAD_SIZE_PADDED>{}),
get_slice_tile(kv_tile.k_tile,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kN0, k0_loops * kK0>{}));
sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_SIZE, k0_loops * HEAD_SIZE_PADDED>{}));
}
else
{
gemm_1(o_acc,
get_slice_tile(sp(sp_reg_idx).p,
sequence<0, (k1_loops - 1) * kK1>{},
sequence<kM0, k1_loops * kK1>{}),
sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_Q, k1_loops * HEAD_SIZE_PADDED>{}),
get_slice_tile(kv_tile.v_tile,
sequence<0, (k1_loops - 1) * kK1>{},
sequence<kN1, k1_loops * kK1>{}));
sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_SIZE, k1_loops * HEAD_SIZE_PADDED>{}));
}
};
@@ -856,21 +850,21 @@ struct UnifiedAttentionPipeline
clear_tile(sp(sp_reg_idx).sp_compute); // initialize C
gemm_0(sp(sp_reg_idx).sp_compute,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_Q, k0_loops * HEAD_SIZE_PADDED>{}),
get_slice_tile(kv_tile.k_tile,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kN0, k0_loops * kK0>{}));
sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_SIZE, k0_loops * HEAD_SIZE_PADDED>{}));
}
else
{
gemm_1(o_acc,
get_slice_tile(sp(sp_reg_idx).p,
sequence<0, (k1_loops - 1) * kK1>{},
sequence<kM0, k1_loops * kK1>{}),
sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_Q, k1_loops * HEAD_SIZE_PADDED>{}),
get_slice_tile(kv_tile.v_tile,
sequence<0, (k1_loops - 1) * kK1>{},
sequence<kN1, k1_loops * kK1>{}));
sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{},
sequence<BLOCK_SIZE, k1_loops * HEAD_SIZE_PADDED>{}));
fmha_alu0(number<1>{} - sp_reg_idx);
}
};
@@ -915,7 +909,7 @@ struct UnifiedAttentionPipeline
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
q_origin.at(number<0>{}), kv_token_start, number<kM0>{}, number<kN0>{});
q_origin.at(number<0>{}), kv_token_start, number<BLOCK_Q>{}, number<BLOCK_SIZE>{});
if(need_perpixel_check)
{
set_tile_if(sp(sp_reg_idx).sp_compute,
@@ -1026,7 +1020,7 @@ struct UnifiedAttentionPipeline
cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);
Scheduler::schedule(cl_p, number<3>{});
kv_token_start += kN0;
kv_token_start += BLOCK_SIZE;
if(num_total_loop <= ++i_total_loops)
{
result = false;
@@ -1073,7 +1067,7 @@ struct UnifiedAttentionPipeline
Scheduler::schedule(cl_p, number<2>{});
fmha_mask(xdl_SP_p01_reg_idx);
kv_token_start += kN0;
kv_token_start += BLOCK_SIZE;
if(num_total_loop <= ++i_total_loops)
{
result = false;
@@ -1151,7 +1145,7 @@ struct UnifiedAttentionPipeline
fmha_alu0(number<0>{});
fmha_alu_D_upd();
kv_token_start += kN0;
kv_token_start += BLOCK_SIZE;
++i_total_loops;
if(num_total_loop <= i_total_loops)
{

View File

@@ -19,6 +19,7 @@ template <typename QDataType_,
typename OaccDataType_,
typename ODataType_,
typename UnifiedAttentionShape_,
typename FmhaMask_,
typename Traits_>
struct UnifiedAttentionPipelineProblem
{
@@ -39,6 +40,7 @@ struct UnifiedAttentionPipelineProblem
using ODataType = remove_cvref_t<ODataType_>;
using UnifiedAttentionShape = remove_cvref_t<UnifiedAttentionShape_>;
using Traits = remove_cvref_t<Traits_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
static constexpr index_t kNumGemm0Warps = UnifiedAttentionShape::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = UnifiedAttentionShape::NumGemm1Warps;
@@ -46,9 +48,7 @@ struct UnifiedAttentionPipelineProblem
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kPadHeadDim = Traits::kPadHeadDim;
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
static constexpr auto BiasEnum = Traits::BiasEnum;