composable_kernel/example/ck_tile/01_fmha/README.md
carlushuang db376dd8a4 introducing ck_tile! (#1216)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
2024-04-15 19:27:12 -05:00


fused multi-head attention

This folder contains an example of fmha (fused multi-head attention) built with the ck_tile tile-programming implementation. It demonstrates the tile-programming API, and it illustrates a new approach to constructing a kernel template and instantiating it (or them) while keeping compile time fast.

build

```bash
# in the root of the composable_kernel repo
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>  # replace <arch> with gfx90a, gfx942, ...
make tile_example_fmha_fwd -j
```

This produces the executable build/bin/tile_example_fmha_fwd.

kernel

The kernel template is fmha_fwd_kernel.hpp; this is the grid-wise op in old ck_tile's terminology. We deliberately put it here to demonstrate that a kernel can be constructed from various internal components of ck_tile. We may still provide an implementation of this kernel template under ck_tile's include path in the future.

There are 3 template parameters for this kernel template.

  • TilePartitioner maps each workgroup to its corresponding tile. fmha_fwd_tile_partitioner.hpp in this folder serves this purpose.
  • FmhaPipeline is one of the block_tile_pipelines (under include/ck_tile/tile_program/block_tile_pipeline), a performance-critical component. We have spent considerable effort optimizing this pipeline, and we may add more performant pipelines to that folder. Users only need to swap this pipeline type to benefit from a different implementation (stay tuned for updated pipelines).
  • EpiloguePipeline modifies and stores out the result in the last phase. Post-fusion is usually done at this stage, so we abstract this concept as well. Currently the epilogue does little work, but it leaves room for future support.
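To make the TilePartitioner role more concrete, here is a host-side sketch of what such a partitioner computes. It assumes a launch grid laid out as (tiles along seqlen_q x tiles along hdim_v, nhead, batch) and hypothetical tile sizes kM0/kN1. TilePartitionerSketch is an illustrative name, not the class in fmha_fwd_tile_partitioner.hpp.

```cpp
#include <array>

// Hypothetical partitioner sketch. The tile sizes are assumptions; a real
// instance gets them from its generated kernel configuration.
struct TilePartitionerSketch {
    static constexpr int kM0 = 128; // Q rows (seqlen_q) handled per workgroup
    static constexpr int kN1 = 128; // O columns (hdim_v) handled per workgroup

    static constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    // Launch grid: x = #tiles along seqlen_q * #tiles along hdim_v,
    //              y = nhead, z = batch.
    static std::array<int, 3> grid_size(int batch, int nhead, int seqlen_q, int hdim_v) {
        return {ceil_div(seqlen_q, kM0) * ceil_div(hdim_v, kN1), nhead, batch};
    }

    // Split a workgroup's x index into (tile index along seqlen_q,
    // tile index along hdim_v).
    static std::array<int, 2> tile_index(int block_x, int hdim_v) {
        const int num_tile_n1 = ceil_div(hdim_v, kN1);
        return {block_x / num_tile_n1, block_x % num_tile_n1};
    }
};
```

Under these assumed tile sizes, the executable's defaults (-b=2 -h=8 -s=3328 -d=128) would map to a 26 x 8 x 2 grid, because 3328 / 128 = 26 and hdim_v fits in a single kN1 tile.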

codegen

To speed up compilation, we instantiate the kernels into separate files so that the CMake/Make build can compile them in parallel. The generate.py script does this. You can also read this script to learn how to instantiate a kernel instance step by step, as described in the FMHA_FWD_KERNEL_BODY variable.
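The underlying C++ mechanism can be sketched with plain explicit instantiation. A shared header declares the template plus extern template declarations, and each generated file holds exactly one explicit instantiation definition, so every instance compiles in its own translation unit. The sketch collapses this into one file. fmha_traits_sketch and fmha_fwd_instance are hypothetical names, not what generate.py emits.

```cpp
// --- what a shared header might contain ---------------------------------
// Hypothetical compile-time traits identifying one kernel instance.
template <int HDim, bool IsCausal>
struct fmha_traits_sketch {
    static constexpr int hdim = HDim;
    static constexpr bool is_causal = IsCausal;
};

// Entry point templated on the traits. The placeholder body stands in for
// the code that would configure and launch the real GPU kernel.
template <typename Traits>
int fmha_fwd_instance(int seqlen_q) {
    return Traits::hdim * seqlen_q + (Traits::is_causal ? 1 : 0);
}

// Explicit instantiation declarations: translation units that include the
// header do not instantiate these themselves, they link against the
// generated files instead.
extern template int fmha_fwd_instance<fmha_traits_sketch<64, false>>(int);
extern template int fmha_fwd_instance<fmha_traits_sketch<128, true>>(int);

// --- generated file #1 (one .cpp per instance) --------------------------
template int fmha_fwd_instance<fmha_traits_sketch<64, false>>(int);

// --- generated file #2 ---------------------------------------------------
template int fmha_fwd_instance<fmha_traits_sketch<128, true>>(int);
```

Because each instantiation definition lives in its own file, adding instances adds files that the build can compile in parallel, instead of growing one large translation unit.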

executable

tile_example_fmha_fwd is the example executable, implemented in fmha_fwd.cpp. Run ./bin/tile_example_fmha_fwd -? to list all supported args. Below is an example of the output (subject to change):

```
args:
          -v    weather do CPU validation or not (default:1)
       -mode    kernel mode. 0:batch, 1:group (default:0)
          -b    batch size (default:2)
          -h    num of head, for q (default:8)
        -h_k    num of head, for k/v, 0 means equal to h (default:0)
                 if not equal to h, then this is GQA/MQA case
          -s    seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
                 total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
        -s_k    seqlen_k, 0 means equal to s (default:0)
          -d    head dim for q, k (default:128)
        -d_v    head dim for v, 0 means equal to d (default:0)
    -scale_s    scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
                 note when squant=1, this value will be modified by range_q/k
    -range_q    per-tensor quantization range of q. used if squant=1. (default:2)
    -range_k    per-tensor quantization range of k. used if squant=1. (default:2)
    -range_v    per-tensor quantization range of v. used if squant=1. (default:2)
    -range_p    per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
    -range_o    per-tensor quantization range of o (p*v). used if squant=1. (default:2)
     -squant    if using static quantization fusion or not. 0: original flow(not prefered) (default:0)
                 1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,
                 scale_o according to range_q, range_k, range_v, range_p, range_o
      -iperm    permute input (default:1)
                 if true, will be b*h*s*d, else b*s*h*d
      -operm    permute output (default:1)
       -bias    add bias or not (default:0)
       -prec    data type. fp16/bf16/fp8/bf8 (default:fp16)
       -mask    0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
                 't', top-left causal mask, 'b', bottom-r causal mask
                 't:l,r', top-left sliding window attn(swa) with FA style left right size
                 'b:l,r', bottom-r sliding window attn(swa) with FA style left right size
                 'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
                 'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
                 'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)

    -vlayout    r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
        -lse    0 not store lse, 1 store lse (default:0)
      -kname    if set to 1 will print kernel name (default:0)
       -init    init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1)
```

Example: ./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128 runs an fp16 fmha case with batch=1, nhead=16, sequence length=16384 and hdim=128.
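The relation between -scale_s, -squant and the -range_* arguments can be sketched from the usual per-tensor static quantization convention, where a quantized value x_q represents x_q * range / dtype_max. The sketch below is our own derivation from that convention, not a copy of fmha_fwd.cpp. The struct and function names are hypothetical, and dtype_max = 240 (the largest finite value of the FNUZ fp8 e4m3 format used on gfx94x) is an assumption.

```cpp
#include <cmath>

// Hypothetical container for the three scales the kernel consumes.
struct fmha_scales_sketch {
    float scale_s; // applied to S = Q * K^T
    float scale_p; // quantizes P = exp(S - m) before the second GEMM
    float scale_o; // folds dequantization of P*V and requantization of O
};

// hdim: head dim of q/k. scale_s_arg: the -scale_s value (0 means 1/sqrt(hdim)).
// dtype_max: largest value of the quantized type (assumed, e.g. 240).
fmha_scales_sketch compute_scales(int hdim, float scale_s_arg, bool squant,
                                  float range_q, float range_k, float range_v,
                                  float range_p, float range_o, float dtype_max) {
    float scale_s = (scale_s_arg == 0.f)
                        ? 1.f / std::sqrt(static_cast<float>(hdim))
                        : scale_s_arg;
    float scale_p = 1.f;
    float scale_o = 1.f;
    if (squant) {
        // q = q_q * range_q / dtype_max and k = k_q * range_k / dtype_max, so
        // the dot product of quantized values needs both factors in scale_s.
        scale_s *= (range_q / dtype_max) * (range_k / dtype_max);
        // P in [0, range_p] is mapped onto [0, dtype_max].
        scale_p = dtype_max / range_p;
        // O = (P_q * range_p / dtype_max) * (V_q * range_v / dtype_max);
        // requantizing with dtype_max / range_o leaves this factor.
        scale_o = range_p * range_v / (range_o * dtype_max);
    }
    return {scale_s, scale_p, scale_o};
}
```

With hdim=128 and the default ranges (range_q = range_k = range_v = range_o = 2, range_p = 1), this sketch gives scale_p = 240 and scale_o = 1/240 under the assumed dtype_max.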

support features

We are still in a rapid development stage, so more features and optimizations are coming soon.

hdim

Currently we support hdim of 32/64/128/256 for fp16/bf16, of which 64 and 128 are better optimized. hdim should be a multiple of 8, while seqlen_s can be arbitrary. An arbitrary hdim can be supported through the padding kernel of the qr pipeline (generate.py does not generate it by default).

group/batch mode

We support both batch mode and group mode (varlen, in FA's terminology), selected with -mode=0 or -mode=1. Group mode also supports the different kinds of attention mask described below.
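In group mode, sequences of different lengths are packed back to back, so each batch entry needs its own start offset. One common way to describe this, following FA's cu_seqlens convention, is an exclusive prefix sum over the per-batch lengths. The helpers below are a hypothetical sketch of that bookkeeping, not the exact kernel argument layout.

```cpp
#include <cstddef>
#include <vector>

// Exclusive prefix sum: seqstart[i] is the first packed row of batch i,
// and seqstart[batch] is the total packed length.
std::vector<int> make_seqstart(const std::vector<int>& seqlens) {
    std::vector<int> seqstart(seqlens.size() + 1, 0);
    for (std::size_t i = 0; i < seqlens.size(); ++i)
        seqstart[i + 1] = seqstart[i] + seqlens[i];
    return seqstart;
}

// Packed row index of token `s` of batch `b`.
int packed_row(const std::vector<int>& seqstart, int b, int s) {
    return seqstart[b] + s;
}
```

For per-batch lengths {3, 5, 2}, the offsets are {0, 3, 8, 10}, so token 1 of batch 2 sits at packed row 9.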

MQA/GQA

Setting -h (nhead for q) and -h_k (nhead for k/v) to different values gives MQA/GQA. Note that h must be divisible by h_k (h % h_k == 0) when the two differ.
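In MQA/GQA, several query heads share one k/v head, which is why h must be divisible by h_k. The sketch below shows the usual mapping. It assumes query heads are grouped contiguously, so query head i reads k/v head i / (h / h_k). The function name is hypothetical.

```cpp
#include <cassert>

// Map a query head index to the k/v head it reads from.
// h: number of q heads (-h); h_k: number of k/v heads (-h_k, 0 means h).
int kv_head_for_q_head(int i_head_q, int h, int h_k) {
    if (h_k == 0) h_k = h;          // -h_k=0 falls back to plain MHA
    assert(h % h_k == 0);           // required when h != h_k
    const int nhead_ratio_qk = h / h_k;
    return i_head_q / nhead_ratio_qk;
}
```

With h=8 and h_k=2 (GQA), query heads 4 to 7 read k/v head 1. With h_k=1 (MQA), every query head reads k/v head 0.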

input/output permute, and b*s*3*h*d

The kernel arguments in fmha_fwd_kernel.hpp accept arbitrary strides for the seqlen (stride_q/k/v), nhead and batch dimensions of the q/k/v matrices. This makes both b*h*s*d and b*s*h*d input/output layouts easy to support, and in the executable -iperm=0/1 and -operm=0/1 switch between them. There is no command-line arg to test the b*s*3*h*d layout that torch/FA use by default, but you can get it by setting stride_q/k/v to 3*h*d.
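In the packed b*s*3*h*d layout, q, k and v are three views into one buffer. The sketch below uses element strides and a hypothetical struct to show the base offset and strides each view needs. The seqlen stride is 3*h*d, matching the stride_q/k/v value mentioned above.

```cpp
#include <cstdint>

// Strides (in elements) describing one of q/k/v inside a packed
// b*s*3*h*d buffer. Hypothetical struct for illustration.
struct view_strides {
    std::int64_t base;         // offset of element (0, 0, 0, 0)
    std::int64_t batch_stride; // s * 3 * h * d
    std::int64_t seq_stride;   // 3 * h * d (the stride_q/k/v value)
    std::int64_t head_stride;  // d
};

// which: 0 = q, 1 = k, 2 = v.
view_strides packed_qkv_view(std::int64_t s, std::int64_t h, std::int64_t d, int which) {
    return {which * h * d, s * 3 * h * d, 3 * h * d, d};
}

// Offset of element (batch ib, token is, head ih, dim id) in the view.
std::int64_t offset(const view_strides& v, std::int64_t ib, std::int64_t is,
                    std::int64_t ih, std::int64_t id) {
    return v.base + ib * v.batch_stride + is * v.seq_stride + ih * v.head_stride + id;
}
```

For s=4, h=2, d=8, element (batch 1, token 2, v, head 1, dim 5) lands at offset 333, which equals the direct packed index ((1*4 + 2)*3 + 2)*16 + 1*8 + 5.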

attention bias

Attention bias is supported with a 1*1*s*s layout and float bias values. As with input/output, other layouts can be supported by changing the bias strides, or even extended to b*h*s*s.
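Broadcasting a 1*1*s*s bias across batch and heads amounts to using zero batch and head strides. Non-zero strides turn the same indexing into a full b*h*s*s bias. A minimal sketch with a hypothetical helper name:

```cpp
#include <cstdint>

// Element offset into a bias tensor addressed as
// [batch][head][seqlen_q][seqlen_k]. Zero batch/head strides broadcast
// a single s*s bias to every batch and head.
std::int64_t bias_offset(std::int64_t ib, std::int64_t ih, std::int64_t iq, std::int64_t ik,
                         std::int64_t batch_stride, std::int64_t head_stride,
                         std::int64_t row_stride) {
    return ib * batch_stride + ih * head_stride + iq * row_stride + ik;
}
```

With s_q = s_k = 4 and zero batch/head strides, (batch 1, head 3) reads the same element as (batch 0, head 0). With h=8 and a dense b*h*s*s bias (head stride 16, batch stride 128), the two map to different elements.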

lse

For training, the "log sum exp" (LSE) needs to be stored in the forward pass and used in the backward pass. Set -lse=1 to store it.
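For reference, the LSE of one row of (scaled) attention scores is log(sum_j exp(x_j)). Given the LSE, a backward pass can rebuild each softmax probability as exp(x_j - lse) without another row reduction. The host-side sketch below uses the standard max-subtraction trick for numerical stability. The function name is hypothetical, and it does not reflect the kernel's internal exp/exp2 details.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable log(sum_j exp(x_j)) for one non-empty row of scores.
// Subtracting the row max m keeps exp() from overflowing:
// lse = m + log(sum_j exp(x_j - m)).
float row_lse(const std::vector<float>& row) {
    const float m = *std::max_element(row.begin(), row.end());
    float sum = 0.f;
    for (float x : row) sum += std::exp(x - m);
    return m + std::log(sum);
}
```

For example, a row {1000, 1000} gives 1000 + log 2, although evaluating exp(1000) directly would overflow a float.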

vlayout

We support the V matrix in both row-major (seqlen*hdim) and col-major (hdim*seqlen) layouts. V is accumulated (reduced) along seqlen, and the current AMD mfma layout expects each thread to hold contiguous registers along the reduce dimension. As a result, the col-major V layout is easier to support. However, col-major is not necessarily faster than row-major, because many factors affect overall performance. The -vlayout=r/c option switches between the two layouts for testing.
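The two V layouts differ only in how element (i_s, i_d) is addressed. The sketch below uses a hypothetical helper to show why col-major suits a reduction along seqlen. For a fixed hdim column, consecutive seqlen positions are adjacent in memory in col-major but hdim elements apart in row-major.

```cpp
#include <cstdint>

// Offset of V[i_s][i_d] for a seqlen x hdim matrix.
// 'r' (row-major, seqlen*hdim): hdim is the contiguous dimension.
// 'c' (col-major, hdim*seqlen): seqlen is the contiguous dimension.
std::int64_t v_offset(char vlayout, std::int64_t i_s, std::int64_t i_d,
                      std::int64_t seqlen, std::int64_t hdim) {
    return vlayout == 'r' ? i_s * hdim + i_d : i_d * seqlen + i_s;
}
```

With seqlen=64 and hdim=128, stepping from seqlen position 4 to 5 in the same column moves 1 element in col-major and 128 elements in row-major.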

attention mask

We support causal masks and sliding window attention (swa) masks in both batch and group mode, aligned either to the top-left or the bottom-right. Internally, every mask is expressed as a generic attention mask coordinate, which gives each batch a uniform way to locate the elements that need to be masked out.

The FA/xformer style with window_size_left/right is more popular, so we accept window sizes as parameters and convert them internally to our generic coordinate (which can express more cases). The table below shows how to select different kinds of mask from the command line.

| mask case | cmdline | FA style | xformer style |
|---|---|---|---|
| no mask | -mask=0 (default) | | |
| causal mask from top-left | -mask=1 or -mask=t | -mask=t:-1,0 | -mask=xt:-1 |
| causal mask from bottom-right | -mask=2 or -mask=b | -mask=b:-1,0 | -mask=xb:-1 |
| swa from top-left | | -mask=t:3,5 | -mask=xt:4 |
| swa from bottom-right | | -mask=b:10,11 | -mask=xb:16 |

Note that FA uses bottom-right by default for the swa case, whereas here you must specify top-left or bottom-right explicitly.
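The FA-style window semantics behind the t:l,r and b:l,r options can be written as a plain predicate over (query row i, key column j). This reference sketch covers the semantics only. It assumes FA conventions: a negative size means unlimited, and bottom-right alignment shifts the diagonal by seqlen_k - seqlen_q. It is not how ck_tile computes its generic mask coordinates, and the function name is hypothetical.

```cpp
#include <cstdint>

// Returns true if score (i, j) is kept, i.e. NOT masked out.
// left/right: FA-style window sizes; a negative value means unlimited.
// top_left: align the window diagonal to the top-left corner ('t'),
//           otherwise to the bottom-right corner ('b').
bool is_visible(std::int64_t i, std::int64_t j, std::int64_t seqlen_q, std::int64_t seqlen_k,
                std::int64_t left, std::int64_t right, bool top_left) {
    // Key column on the diagonal for query row i.
    const std::int64_t diag = top_left ? i : i + (seqlen_k - seqlen_q);
    const bool left_ok = left < 0 || j >= diag - left;
    const bool right_ok = right < 0 || j <= diag + right;
    return left_ok && right_ok;
}
```

For example, -mask=t:-1,0 (causal from top-left) keeps j <= i. With -mask=t:3,5 and seqlen_q = seqlen_k = 32, row 10 keeps key columns 7 through 15.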

dropout

TBD

FP8 experimental support

As described in this blog, we have experimental support for fp8 fmha kernels. You can evaluate the performance by passing -prec=fp8 to tile_example_fmha_fwd on a gfx940/941/942 machine with ROCm 6.0+.

Currently fp8 supports only -vlayout=c (hdim*seqlen for the V matrix) and -squant=1 (static quantization) with hdim=128. Full feature support will come later.