[CK_TILE] support alibi (#1269)

* add alibi support

* fix code

* update code based on comment

* Support more hdim

* fix fp8 bias

* support seqlen_k=0 case

* remove unused printf

* fix format

---------

Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
carlushuang
2024-05-07 22:32:54 +08:00
committed by GitHub
parent 6d073d31bb
commit 851c3ed157
24 changed files with 948 additions and 115 deletions

View File

@@ -30,27 +30,29 @@ args:
-mode kernel mode. 0:batch, 1:group (default:0)
-b batch size (default:2)
-h num of head, for q (default:8)
-h_k num of head, for k/v, 0 means equal to h (default:0)
-h_k num of head, for k/v, -1 means equal to h (default:-1)
if not equal to h, then this is GQA/MQA case
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
-s_k seqlen_k, 0 means equal to s (default:0)
-s_k seqlen_k, -1 means equal to s (default:-1)
-d head dim for q, k (default:128)
-d_v head dim for v, 0 means equal to d (default:0)
-d_v head dim for v, -1 means equal to d (default:-1)
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
note when squant=1, this value will be modified by range_q/k
-range_q per-tensor quantization range of q. used if squant=1. (default:2)
-range_k per-tensor quantization range of k. used if squant=1. (default:2)
-range_v per-tensor quantization range of v. used if squant=1. (default:2)
-range_q per-tensor quantization range of q. used if squant=1. (default:16)
-range_k per-tensor quantization range of k. used if squant=1. (default:16)
-range_v per-tensor quantization range of v. used if squant=1. (default:16)
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:2)
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
-squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0)
1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,
scale_o according to range_q, range_k, range_v, range_p, range_o
-iperm permute input (default:1)
if true, will be b*h*s*d, else b*s*h*d
-operm permute output (default:1)
-bias add bias or not (default:0)
-bias n or 0, no bias (default:n)
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
a(libi) or 2, alibi with 1*h. a:1, b*h
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
't', top-left causal mask, 'b', bottom-r causal mask
@@ -59,11 +61,11 @@ args:
'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
-lse 0 not store lse, 1 store lse (default:0)
-kname if set to 1 will print kernel name (default:0)
-init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1)
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
```
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
@@ -85,6 +87,9 @@ If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support prov
### attention bias
Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number.
### alibi
alibi is supported
### lse
For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1`