fused multi-head attention
This folder contains an example of fmha (fused multi-head attention) built on the ck_tile tile-programming implementation. It demonstrates how to use the tile-programming API. It also illustrates a new approach to constructing a kernel template and instantiating it (or them) while keeping compile time fast.
build
```bash
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_fmha_fwd -j
```
This will result in an executable build/bin/tile_example_fmha_fwd.
kernel
The kernel template is fmha_fwd_kernel.hpp. In old ck_tile terminology, this is the grid-wise op. We put it here on purpose, to demonstrate that a kernel can be constructed from various internal components of ck_tile. In the future, we may still provide an implementation of the kernel template under ck_tile's include path.
There are 3 template parameters for this kernel template.
- TilePartitioner maps each workgroup to its corresponding tile. fmha_fwd_tile_partitioner.hpp in this folder serves this purpose.
- FmhaPipeline is one of the block_tile_pipelines (under include/ck_tile/tile_program/block_tile_pipeline). It is the performance-critical component. We did a lot of optimization and experimentation on the pipeline, and may work out more performant pipelines and add them to that folder. Users only need to replace this pipeline type to benefit from a different performant implementation (stay tuned for updated pipelines).
- EpiloguePipeline modifies and stores the result in the last phase. Post-fusion is usually done at this stage, so we abstract this concept as well. The epilogue currently does very little, but it leaves room for possible future support.
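The following minimal host-side C++ sketch makes the composition concrete by assembling a kernel template from its three template parameters. SimpleTilePartitioner, DummyPipeline, ScaleEpilogue and FmhaFwdKernelSketch are hypothetical stand-ins, not ck_tile types. The real kernel runs on the GPU with much richer interfaces. The sketch only shows how each parameter can be swapped without touching the other two.

```cpp
#include <cassert>

// Hypothetical tile coordinate owned by one workgroup.
struct TileCoord { int batch; int head; int m_block; };

// Stand-in for TilePartitioner: maps a linear workgroup id to (batch, head, m_block).
struct SimpleTilePartitioner {
    int num_m_blocks;  // number of seqlen_q tiles per (batch, head)
    int num_heads;
    TileCoord operator()(int block_id) const {
        int m = block_id % num_m_blocks;
        int rest = block_id / num_m_blocks;
        return {rest / num_heads, rest % num_heads, m};
    }
};

// Stand-in for FmhaPipeline: returns a fake "accumulator" derived from the tile coordinate.
struct DummyPipeline {
    float operator()(const TileCoord& t) const {
        return float(t.batch * 100 + t.head * 10 + t.m_block);
    }
};

// Stand-in for EpiloguePipeline: post-processing applied before the result is stored.
struct ScaleEpilogue {
    float scale;
    float operator()(float acc) const { return acc * scale; }
};

// Hypothetical kernel template: locate the tile, run the pipeline, hand the result to the epilogue.
template <typename TilePartitioner, typename FmhaPipeline, typename EpiloguePipeline>
struct FmhaFwdKernelSketch {
    TilePartitioner partitioner;
    FmhaPipeline pipeline;
    EpiloguePipeline epilogue;

    float operator()(int block_id) const {
        TileCoord tile = partitioner(block_id);
        return epilogue(pipeline(tile));
    }
};
```

Replacing DummyPipeline with a different type that has the same call signature changes the computation while leaving partitioning and the epilogue untouched.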
codegen
To speed up compilation, we instantiate the kernels into separate files. This lets the CMake/Make system build them in parallel. The generate.py script does this. You can also read the script to learn, step by step, how a kernel instance is instantiated. The steps are described in the FMHA_FWD_KERNEL_BODY variable.
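The per-file instantiation idea can be sketched as follows. This is a hypothetical C++ illustration, not a port of generate.py. KernelInstance, kBodyTemplate and generate() are made-up names, and the {PLACEHOLDER} syntax is an assumption. Each instance is rendered from one shared template body into its own source file, which is what allows the build system to compile instances in parallel.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical description of one kernel instance.
struct KernelInstance {
    std::string dtype;  // e.g. "fp16"
    int hdim;           // e.g. 128
    std::string mode;   // "batch" or "group"
};

struct GeneratedFile { std::string name; std::string body; };

// Hypothetical shared template body; every instance becomes its own translation unit.
const std::string kBodyTemplate =
    "// auto-generated: one kernel instance per translation unit\n"
    "// dtype={DTYPE} hdim={HDIM} mode={MODE}\n";

// Replace every occurrence of `key` in `text` with `value`.
std::string replace_all(std::string text, const std::string& key, const std::string& value) {
    for (std::size_t pos = text.find(key); pos != std::string::npos;
         pos = text.find(key, pos + value.size())) {
        text.replace(pos, key.size(), value);
    }
    return text;
}

// Render one source file per instance.
std::vector<GeneratedFile> generate(const std::vector<KernelInstance>& instances) {
    std::vector<GeneratedFile> files;
    for (const auto& inst : instances) {
        std::string body = replace_all(kBodyTemplate, "{DTYPE}", inst.dtype);
        body = replace_all(body, "{HDIM}", std::to_string(inst.hdim));
        body = replace_all(body, "{MODE}", inst.mode);
        std::string name = "fmha_fwd_" + inst.dtype + "_d" + std::to_string(inst.hdim) +
                           "_" + inst.mode + ".cpp";
        files.push_back({name, body});
    }
    return files;
}
```

For example, generate({{"fp16", 128, "batch"}}) yields a single file named fmha_fwd_fp16_d128_batch.cpp.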
executable
tile_example_fmha_fwd is the example executable, implemented in fmha_fwd.cpp. Run ./bin/tile_example_fmha_fwd -? to list all the arguments. Below is an example of the output (subject to change):
```
args:
  -v            weather do CPU validation or not (default:1)
  -mode         kernel mode. 0:batch, 1:group (default:0)
  -b            batch size (default:2)
  -h            num of head, for q (default:8)
  -h_k          num of head, for k/v, -1 means equal to h (default:-1)
                if not equal to h, then this is GQA/MQA case
  -s            seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
                total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
                also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode)
  -s_k          seqlen_k (including new key/value), -1 means equal to s (default:-1)
  -d            head dim for q, k (default:128)
  -d_v          head dim for v, -1 means equal to d (default:-1)
  -scale_s      scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
                note when squant=1, this value will be modified by range_q/k
  -range_q      per-tensor quantization range of q. used if squant=1. (default:16)
  -range_k      per-tensor quantization range of k. used if squant=1. (default:16)
  -range_v      per-tensor quantization range of v. used if squant=1. (default:16)
  -range_p      per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
  -range_o      per-tensor quantization range of o (p*v). used if squant=1. (default:16)
  -squant       if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto)
                0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O.
                calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o
  -iperm        permute input (default:1)
                if true, will be b*h*s*d, else b*s*h*d
  -operm        permute output (default:1)
  -bias         n or 0, no bias (default:n)
                e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
                a(libi) or 2, alibi with 1*h. a:1, b*h
  -prec         data type. fp16/bf16/fp8/bf8 (default:fp16)
  -mask         0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
                't', top-left causal mask, 'b', bottom-r causal mask
                't:l,r', top-left sliding window attn(swa) with FA style left right size
                'b:l,r', bottom-r sliding window attn(swa) with FA style left right size
                'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
                'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
                'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
  -vlayout      r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
  -lse          0 not store lse, 1 store lse (default:0)
  -kname        if set to 1 will print kernel name (default:0)
  -init         init method. ui, uniform random int, ni, normalized random int (default:uf)
                uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization
  -seed         random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
  -drop_seed    seed for random number generator (default:1)
  -drop_offset  offset for random number generator (default:0)
  -drop_prefs   seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0)
  -warmup       number of iterations before benchmark the kernel (default:5)
  -repeat       number of iterations to benchmark the kernel (default:20)
```
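When left at their defaults, several of the arguments above fall back to another argument. The sketch below shows those fallback rules, using a hypothetical FmhaArgs struct. The executable's own argument handling may differ, and the sketch ignores the -squant adjustment of scale_s.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical struct mirroring a few of the arguments listed above (defaults from the help text).
struct FmhaArgs {
    int h = 8, h_k = -1;     // -h, -h_k
    int s = 3328, s_k = -1;  // -s, -s_k
    int d = 128, d_v = -1;   // -d, -d_v
    float scale_s = 0.f;     // -scale_s
};

// Apply the "-1 means equal to ..." and "0 means 1/sqrt(hdim)" rules.
FmhaArgs resolve_defaults(FmhaArgs a) {
    if (a.h_k < 0) a.h_k = a.h;
    if (a.s_k < 0) a.s_k = a.s;
    if (a.d_v < 0) a.d_v = a.d;
    if (a.scale_s == 0.f) a.scale_s = 1.f / std::sqrt(static_cast<float>(a.d));
    return a;
}
```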
Example 1: ./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128 runs an fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, in fp16.
Example 2: ./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234 runs an fmha case with batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=10 (in GPU memory), drop_offset=1234 (in GPU memory), in fp16.
support features
We are still in a rapid development stage, so more features and optimizations are coming soon.
hdim
We currently support hdim of 32/64/128/256 for fp16/bf16. Of these, 64 and 128 are better optimized. hdim should be a multiple of 8, while seqlen_s can be arbitrary. An arbitrary hdim can be supported through the padding kernel of the qr pipeline, which generate.py does not generate by default.
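The sketch below gives a rough picture of how an arbitrary hdim relates to the generated sizes. It picks the smallest supported hdim that can hold the requested one and reports whether padding would be needed. The list 32/64/128/256 comes from the text above. The selection rule and the names choose_hdim_bucket and HdimChoice are assumptions for illustration, not ck_tile's dispatch logic.

```cpp
#include <array>
#include <cassert>

// Supported hdim values from the text above.
constexpr std::array<int, 4> kSupportedHdims = {32, 64, 128, 256};

struct HdimChoice {
    int bucket;          // generated hdim that would be used
    bool needs_padding;  // true if hdim is smaller than the bucket
    bool supported;      // false if hdim exceeds the largest bucket
};

// Hypothetical rule: smallest supported hdim >= requested hdim.
HdimChoice choose_hdim_bucket(int hdim) {
    for (int b : kSupportedHdims) {
        if (hdim <= b) return {b, hdim != b, true};
    }
    return {0, false, false};
}
```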
group/batch mode
Both batch mode and group mode (varlen, in FA's terms) are supported, selected with -mode=0 or -mode=1 respectively. Group mode also supports the different kinds of attention mask (see below).
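In group mode, the sequences of a batch are packed back to back, so per-batch lengths are typically converted into start offsets with a prefix sum. The sketch below assumes a seqstart-style offset array. parse_seqlens and make_seqstart are hypothetical helpers. The first one mirrors the "-s=s0,s1,s2" syntax of the executable.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Parse a comma-separated list like "3,5,2" into per-batch sequence lengths.
std::vector<int> parse_seqlens(const std::string& arg) {
    std::vector<int> out;
    std::stringstream ss(arg);
    std::string tok;
    while (std::getline(ss, tok, ',')) out.push_back(std::stoi(tok));
    return out;
}

// Prefix sum: seqstart[b] is the first row of batch b in the packed tensor,
// and seqstart[batch] is the total number of rows.
std::vector<int> make_seqstart(const std::vector<int>& seqlens) {
    std::vector<int> start(seqlens.size() + 1, 0);
    for (std::size_t i = 0; i < seqlens.size(); ++i) start[i + 1] = start[i] + seqlens[i];
    return start;
}
```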
MQA/GQA
Setting -h (nhead for q) and -h_k (nhead for k/v) to different numbers gives MQA/GQA. When the numbers differ, h % h_k == 0 must hold.
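With h query heads and h_k key/value heads, each k/v head is shared by h / h_k query heads. The sketch below assumes the common convention that consecutive query heads share a k/v head. kv_head_for_q_head is a hypothetical helper.

```cpp
#include <cassert>
#include <stdexcept>

// Map a query head index to the key/value head it reads in MQA/GQA,
// assuming consecutive groups of h / h_k query heads share one k/v head.
int kv_head_for_q_head(int q_head, int h, int h_k) {
    if (h_k <= 0 || h % h_k != 0) throw std::invalid_argument("h must be divisible by h_k");
    int group = h / h_k;  // query heads per k/v head
    return q_head / group;
}
```

MQA is the special case h_k = 1, where every query head maps to k/v head 0.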
input/output permute, and b*s*3*h*d
The kernel arguments in fmha_fwd_kernel.hpp accept arbitrary strides for the seqlen (stride_q/k/v), nhead, and batch dimensions of the q/k/v matrices. This makes it easy to support either b*h*s*d or b*s*h*d input/output permutes. Through the executable, -iperm=0/1 and -operm=0/1 are a convenient way to select between them. We don't provide a command-line arg to test the b*s*3*h*d layout, which torch/FA use by default. It is trivial to support by setting stride_q/k/v to 3*h*d.
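The stride arithmetic behind these layouts can be sketched as below. Strides, bhsd, bshd, bs3hd and offset are hypothetical helpers, not the kernel's argument struct. Only the seqlen stride of 3*h*d for the packed layout comes from the text above. In that packed layout, each of q/k/v is assumed to start at an extra base offset of which*h*d (0 for q, 1 for k, 2 for v).

```cpp
#include <cassert>
#include <cstdint>

// Element strides of a 4-D tensor addressed as (batch, head, seq, dim).
struct Strides { int64_t batch, head, seq, dim; };

// b*h*s*d layout (-iperm=1 / -operm=1)
Strides bhsd(int64_t h, int64_t s, int64_t d) { return {h * s * d, s * d, d, 1}; }
// b*s*h*d layout (-iperm=0 / -operm=0)
Strides bshd(int64_t h, int64_t s, int64_t d) { return {s * h * d, d, h * d, 1}; }
// One of q/k/v viewed inside a packed b*s*3*h*d tensor: seqlen stride is 3*h*d.
Strides bs3hd(int64_t h, int64_t s, int64_t d) { return {s * 3 * h * d, d, 3 * h * d, 1}; }

// Linear element offset of (b, hh, ss, dd) under the given strides.
int64_t offset(const Strides& st, int64_t b, int64_t hh, int64_t ss, int64_t dd) {
    return b * st.batch + hh * st.head + ss * st.seq + dd * st.dim;
}
```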
attention bias
Attention bias is supported with a 1*1*s*s layout and float bias values. As with input/output, other layouts can be supported by changing the bias stride values, up to b*h*s*s.
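The sketch below shows how a single (batch, head, row, col) indexing can serve the 1*1*s*s, 1*h*s*s and b*h*s*s bias layouts, by setting the strides of broadcast dimensions to zero. BiasView and make_bias_view are hypothetical, and a unit column stride is assumed.

```cpp
#include <cassert>
#include <cstdint>

// Read-only view over a bias tensor; broadcast dimensions get stride 0.
struct BiasView {
    const float* data;
    int64_t stride_batch, stride_head, stride_row;  // column stride assumed to be 1
    float at(int64_t b, int64_t h, int64_t i, int64_t j) const {
        return data[b * stride_batch + h * stride_head + i * stride_row + j];
    }
};

// per_batch / per_head select which dimensions are stored rather than broadcast.
BiasView make_bias_view(const float* data, bool per_batch, bool per_head,
                        int64_t nhead, int64_t s_q, int64_t s_k) {
    int64_t row = s_k;
    int64_t head = per_head ? s_q * s_k : 0;
    int64_t batch = per_batch ? (per_head ? nhead : 1) * s_q * s_k : 0;
    return {data, batch, head, row};
}
```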
alibi
ALiBi is supported, selected with -bias=a (1*h layout) or -bias=a:1 (b*h layout).
lse
For training, the log-sum-exp (lse) needs to be stored in the forward pass and used in the backward pass. This is enabled with -lse=1.
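The quantity stored for each query row is the log-sum-exp of that row of S. It is computed stably by subtracting the row maximum. The minimal sketch below assumes S already includes the scale factor. row_lse and prob_from_lse are hypothetical helpers.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// lse = m + log(sum_j exp(s_j - m)), with m = max_j s_j for numerical stability.
float row_lse(const std::vector<float>& s) {
    float m = -std::numeric_limits<float>::infinity();
    for (float x : s) m = std::max(m, x);
    if (std::isinf(m)) return m;  // empty or fully masked row
    float sum = 0.f;
    for (float x : s) sum += std::exp(x - m);
    return m + std::log(sum);
}

// Recover one softmax probability from the stored lse.
float prob_from_lse(float s_ij, float lse) { return std::exp(s_ij - lse); }
```

This is the usual reason for storing lse: with one float per row, a backward pass can recompute P = exp(S - lse) without keeping the full P matrix.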
vlayout
The V matrix is supported in both row-major (seqlen*hdim) and col-major (hdim*seqlen) layouts. The accumulation (reduction) dimension for V is seqlen. AMD's current mfma layout expects each thread to hold contiguous registers along the reduction dimension, so the col-major V layout is easier to support. However, col-major is not necessarily faster than row-major, since many factors affect overall performance. The -vlayout=r/c option lets you switch between and test the two layouts.
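The two layouts differ in their memory step along seqlen, which is the reduction dimension of the P*V gemm. The sketch below shows this with hypothetical helpers v_offset and reduce_dim_step.

```cpp
#include <cassert>
#include <cstdint>

// Element offset of V[seq][dim] for -vlayout=r (seqlen*hdim) or -vlayout=c (hdim*seqlen).
int64_t v_offset(char layout, int64_t seq, int64_t dim, int64_t seqlen, int64_t hdim) {
    return layout == 'r' ? seq * hdim + dim : dim * seqlen + seq;
}

// Distance in elements between V[seq][dim] and V[seq + 1][dim], i.e. along the reduction dimension.
int64_t reduce_dim_step(char layout, int64_t seqlen, int64_t hdim) {
    return v_offset(layout, 1, 0, seqlen, hdim) - v_offset(layout, 0, 0, seqlen, hdim);
}
```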
attention mask
We support causal and sliding window attention (swa) masks in both batch and group mode, aligned either from the top-left or the bottom-right.
Internally, every mask expression is unified into a generic attention mask coordinate. This gives each batch a uniform way to locate the pixels that need to be masked out.

Since the FA/xformer style with window_size_left/right is more popular, we accept window_size as a parameter and convert it internally to our generic coordinate, which can express more cases. The table below shows how to get different kinds of mask through the command line.
| mask case | cmdline | FA style | xformer style |
|---|---|---|---|
| no mask | -mask=0 (default) | | |
| causal mask from top-left | -mask=1 or -mask=t | -mask=t:-1,0 | -mask=xt:-1 |
| causal mask from bottom-right | -mask=2 or -mask=b | -mask=b:-1,0 | -mask=xb:-1 |
| swa from top-left | | -mask=t:3,5 | -mask=xt:4 |
| swa from bottom-right | | -mask=b:10,11 | -mask=xb:16 |
Note that FA uses bottom-right by default to express the swa case. Here, you must explicitly specify top-left or bottom-right.
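An FA-style left/right window can be read as a per-element visibility test. The sketch below follows the flash-attention convention that a negative window size means unbounded on that side. It assumes t:l,r and b:l,r map to left, right, and top-left or bottom-right alignment. is_visible is a hypothetical helper and does not reproduce ck_tile's internal generic coordinate.

```cpp
#include <cassert>
#include <cstdint>

// Returns true if query row i may attend to key column j.
// left/right < 0 means no limit on that side; bottom_right shifts the diagonal so that
// the last query row lines up with the last key column.
bool is_visible(int64_t i, int64_t j, int64_t left, int64_t right,
                bool bottom_right, int64_t seqlen_q, int64_t seqlen_k) {
    int64_t diag = bottom_right ? i + (seqlen_k - seqlen_q) : i;
    bool left_ok = left < 0 || j >= diag - left;
    bool right_ok = right < 0 || j <= diag + right;
    return left_ok && right_ok;
}
```

With left=-1 and right=0, this reduces to the causal masks in the table. For example, with seqlen_q=2 and seqlen_k=4, bottom-right alignment lets query row 0 see keys 0 through 2.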
dropout
TBD
FP8 experimental support
As described in this blog, fp8 fmha kernels have experimental support. You can evaluate their performance by passing -prec=fp8 to tile_example_fmha_fwd on a gfx940/941/942 machine with ROCm 6.0+.
For fp8, only -vlayout=c (hdim*seqlen for the V matrix) and -squant=1 (static quantization) with hdim=128 are currently supported. Full feature support will come later.