From 89a75a97fad27c7f279ba5f984c3ef21ac0495e8 Mon Sep 17 00:00:00 2001
From: carlushuang
Date: Tue, 9 Apr 2024 19:01:25 +0000
Subject: [PATCH] fix some bug in group-mode masking and codegen. update README

---
 example/ck_tile/01_fmha/README.md             |  37 ++++--
 example/ck_tile/01_fmha/fmha_fwd.cpp          |  28 ++--
 example/ck_tile/01_fmha/generate.py           |   6 +
 example/ck_tile/01_fmha/mask.hpp              | 125 ++++++++++++------
 example/ck_tile/01_fmha/script/smoke_test.sh  |  18 +--
 .../ck_tile/ops/fmha/block/block_masking.hpp  |   5 +-
 .../block_fmha_pipeline_qr_ks_vs_async.hpp    |   5 +-
 7 files changed, 148 insertions(+), 76 deletions(-)

diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md
index 65ce774531..869a635bf8 100644
--- a/example/ck_tile/01_fmha/README.md
+++ b/example/ck_tile/01_fmha/README.md
@@ -32,7 +32,8 @@ args:
   -h          num of head, for q (default:8)
   -h_k        num of head, for k/v, 0 means equal to h (default:0)
               if not equal to h, then this is GQA/MQA case
-  -s          seqlen_q (default:3328)
+  -s          seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
+              total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
   -s_k        seqlen_k, 0 means equal to s (default:0)
   -d          head dim for q, k (default:128)
   -d_v        head dim for v, 0 means equal to d (default:0)
@@ -45,14 +46,19 @@ args:
   -operm      permute output (default:1)
   -bias       add bias or not (default:0)
   -prec       data type. fp16/bf16/fp8/bf8 (default:fp16)
-  -mask       0: no mask, 1: top-left, 2:bottom-right (default:0)
-              't:l,r', top-left local-attn with left right size
-              'b:l,r', bottom-r local-attn with left right size
-              'g:y,x', generic attention mask coordinate with y/x size
-
+  -mask       0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
+              't', top-left causal mask, 'b', bottom-r causal mask
+              't:l,r', top-left sliding window attn(swa) with FA style left right size
+              'b:l,r', bottom-r sliding window attn(swa) with FA style left right size
+              'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
+              'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
+              'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
+
   -vlayout    r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
   -lse        0 not store lse, 1 store lse (default:0)
   -kname      if set to 1 will print kernel name (default:0)
+  -init       init method. 0:random int, 1:random float, 2:trig float (default:1)
+  -seed       random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
 ```
 
 Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
@@ -63,7 +69,7 @@ Currently we are still in rapid development stage, so more features/optimization
 Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` is better optimized. hdim should be multiple of 8, while seqlen_s can be arbitrary. For hdim be arbitrary number, it can be support through padding kernel of `qr` pipeline (we didn't generate this in generate.py by default)
 
 ### group/batch mode
-Currently we support both batch and group mode, by setting `-mode` = `0` or `1`, where in group mode we support each batch can have different seqlen
+Currently we support both `batch mode` and `group mode` (or `varlen`, in FA's terms), by setting `-mode` = `0` or `1`. In `group mode`, different kinds of attention masks are also supported (see below).
 
 ### MQA/GQA
 By setting `-h`(nhead for q) and `-h_k`(nhead for k/v) with different number, you can achieve MQA/GQA. Please pay attention that `h % h_K == 0` when you set different numbers.
@@ -80,11 +86,22 @@ For training kernels, "log sum exp" need to store out in forward and used in bac
 ### vlayout
 We support v matrix in both row-major(`seqlen*hdim`) and col-major(`hdim*seqlen`). Since the accumulate(reduce) dimension for V is along `seqlen`, for current AMD's mfma layout which expect each thread to have contiguous register holding pixels along reduce dimension, it's easier to support col-major V layout. However, the performance of col-major is not necessarily faster than row-major, there are many factors that may affect the overall performance. We still provide the `-vlayout=r/c` here to switch/test between different layouts.
 
-### generic attention mask coordinate
-We unify the mask expression into generic attention mask coordinate, providing an uniformed approach to describe causal top-left, causal bottom-right, local attention.
+### attention mask
+We support `causal mask` and `sliding window attention(swa)` mask in both batch and group mode, either from top-left or bottom-right.
+Underneath, we unify the mask expression into the `generic attention mask coordinate`, providing a uniform approach for each batch to locate the pixels that need to be masked out.
 ![](misc/gamc.png)
-(more description to be added)
+Since the FA/xformer style with window_size_left/right is more popular, we accept window_size as a parameter and convert it internally to our generic coordinate (this coordinate can express more cases). Below are some examples of how to achieve different kinds of masks through the cmdline.
+
+| mask case | cmdline | FA style | xformer style |
+|-----------|:-------------:|:-------------:|:-------------:|
+| no mask | `-mask=0`(default) | | |
+| causal mask from top-left | `-mask=1` or `-mask=t` | `-mask=t:-1,0` | `-mask=xt:-1` |
+| causal mask from bottom-right | `-mask=2` or `-mask=b` | `-mask=b:-1,0` | `-mask=xb:-1` |
+| swa from top-left | | `-mask=t:3,5` | `-mask=xt:4` |
+| swa from bottom-right | | `-mask=b:10,11` | `-mask=xb:16` |
+
+Note that FA uses bottom-right by default to express the swa case; here we require you to explicitly specify top-left/bottom-right.
 
 ### dropout
 TBD

diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp
index 0eb17f7b1b..944041c255 100644
--- a/example/ck_tile/01_fmha/fmha_fwd.cpp
+++ b/example/ck_tile/01_fmha/fmha_fwd.cpp
@@ -44,7 +44,10 @@ auto create_args(int argc, char* argv[])
                 "0",
                 "num of head, for k/v, 0 means equal to h\n"
                 "if not equal to h, then this is GQA/MQA case")
-        .insert("s", "3328", "seqlen_q")
+        .insert("s",
+                "3328",
+                "seqlen_q. if group-mode, means the average value of seqlen_q\n"
+                "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
         .insert("s_k", "0", "seqlen_k, 0 means equal to s")
         .insert("d", "128", "head dim for q, k")
         .insert("d_v", "0", "head dim for v, 0 means equal to d")
@@ -59,21 +62,26 @@ auto create_args(int argc, char* argv[])
         .insert("operm", "1", "permute output")
         .insert("bias", "0", "add bias or not")
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
fp16/bf16/fp8/bf8") - .insert( - "mask", - "0", - "0: no mask, 1: top-left, 2:bottom-right\n" - "'t:l,r', top-left sliding window attn with left right size\n" - "'b:l,r', bottom-r sliding window attn with left right size\n" - "'g:y,x', generic attention mask coordinate with y/x size (only use this for debug)\n") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, possitive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, possitive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)\n") .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") .insert("lse", "0", "0 not store lse, 1 store lse") .insert("kname", "0", "if set to 1 will print kernel name") .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") .insert("seed", "11939", - "random seed used for initializing input tensors. 0 to use " - "non-deterministic random number as seed") + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 686dd35d19..b06279fa9a 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -228,6 +228,7 @@ class FmhaFwdApiTrait: @property def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true if self.pipeline_tag == 'qr_async': if self.spad == 't' : return 'true' # always support else : return 'true' @@ -238,6 +239,7 @@ class FmhaFwdApiTrait: @property def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true if self.pipeline_tag == 'qr_async': if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0' else : return f'a.seqlen_k % {self.bn0} == 0' @@ -500,6 +502,10 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw tile = d[hdim_str] hdim = int(hdim_str) for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index 526ea5dd04..56fc8b8b1d 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -29,9 +29,9 @@ struct mask_info if(type == mask_enum::no_mask) os << "n"; else if(type == mask_enum::mask_top_left) - os << "tl(" << left << ":" << right << ")"; + os << "t(" << left << ":" << right << ")"; else if(type == mask_enum::mask_bottom_right) - os << "br(" << left << ":" << right << 
")"; + os << "b(" << left << ":" << right << ")"; else { os << "g(" << y << ":" << x << ")"; @@ -47,66 +47,103 @@ struct mask_info { std::string t = str.substr(0, found_0); std::string v = str.substr(found_0 + 1); - auto found_1 = v.find(","); - if(found_1 == std::string::npos) + if(t == "xt" || t == "xb") { - printf("not supported value %s, %s\n", v.c_str(), str.c_str()); - assert(0); - } - tmp.type = mask_enum::window_generic; - ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); - ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); - // TODO: some validation - if(t == "t") - { - tmp.type = mask_enum::mask_top_left; - auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, true); + // xformer style sliding window attn from top-left + ck_tile::index_t window_size = atoi(v.c_str()); + ck_tile::index_t left_size = -1; + ck_tile::index_t right_size = 0; + if(window_size > 0) + { + left_size = window_size / 2; + right_size = window_size - 1 - left_size; + } + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, t == "xt"); + + tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; tmp.y = r.at(ck_tile::number<0>{}); tmp.x = r.at(ck_tile::number<1>{}); - tmp.left = v0; - tmp.right = v1; - } - else if(t == "b") - { - tmp.type = mask_enum::mask_bottom_right; - auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, false); - tmp.y = r.at(ck_tile::number<0>{}); - tmp.x = r.at(ck_tile::number<1>{}); - tmp.left = v0; - tmp.right = v1; - } - else if(t == "g") - { - tmp.y = v0; - tmp.x = v1; - tmp.left = v0; // TODO: don't use this? - tmp.right = v1; + tmp.left = left_size; + tmp.right = right_size; } else { - printf("not supported type %s, %s\n", t.c_str(), str.c_str()); - assert(0); + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + printf("not supported value %s, %s\n", v.c_str(), str.c_str()); + assert(0); + } + tmp.type = mask_enum::window_generic; + ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + tmp.type = mask_enum::mask_top_left; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "b") + { + tmp.type = mask_enum::mask_bottom_right; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "g") + { + tmp.y = v0; + tmp.x = v1; + tmp.left = v0; // TODO: don't use this? 
+                   tmp.right = v1;
+               }
+               else
+               {
+                   printf("not supported type %s, %s\n", t.c_str(), str.c_str());
+                   assert(0);
+               }
            }
        }
        else
        {
-           // should be 0, 1, 2
-           tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
-           if(tmp.type == mask_enum::mask_top_left)
-           {
+           auto set_causal_top_left = [&]() {
+               tmp.type  = mask_enum::mask_top_left;
               tmp.y     = seqlen_q;
               tmp.x     = 1;
               tmp.left  = -1;
               tmp.right = 0;
-           }
-           else if(tmp.type == mask_enum::mask_bottom_right)
-           {
+           };
+           auto set_causal_bottom_right = [&]() {
+               tmp.type  = mask_enum::mask_bottom_right;
               tmp.y     = seqlen_q;
               tmp.x     = seqlen_k - seqlen_q + 1;
               tmp.left  = -1;
               tmp.right = 0;
+           };
+           if(str == "t")
+               set_causal_top_left();
+           else if(str == "b")
+               set_causal_bottom_right();
+           else
+           {
+               tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
+               if(tmp.type == mask_enum::mask_top_left)
+               {
+                   set_causal_top_left();
+               }
+               else if(tmp.type == mask_enum::mask_bottom_right)
+               {
+                   set_causal_bottom_right();
+               }
           }
        }
        return tmp;

diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test.sh
index 6b7bf8fe41..5a90b319a5 100644
--- a/example/ck_tile/01_fmha/script/smoke_test.sh
+++ b/example/ck_tile/01_fmha/script/smoke_test.sh
@@ -8,23 +8,25 @@
 export CK_WARMUP=0
 export CK_REPEAT=1
 COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
-mode=0
+# mode=0
+# export HIP_VISIBLE_DEVICES=4

 for prec in "fp16" "bf16" ; do
-# for mode in 1 0 ; do
+for mode in 1 0 ; do
 for perm in 0 1 ; do
 for vlayout in "r" "c" ; do
 for hdim in 32 64 128 256 ; do
 for lse in 0 1 ; do
 for bias in 0 1 ; do
-$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
 $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
 $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=120 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
 done
 done
@@ -32,4 +34,4 @@ done
 done
 done
 done
-#done
+done

diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp
index 39447ca99e..7fb1c19b5f 100644
--- a/include/ck_tile/ops/fmha/block/block_masking.hpp
+++ b/include/ck_tile/ops/fmha/block/block_masking.hpp
@@ -280,9 +280,8 @@ struct SimplifiedGenericAttentionMask
         }
         else
         {
-            // no need to do min/max here, since i_x will never be < 0 or >= x_total
-            index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
-            index_t x_end   = i_y + x;      // this could be larger than x_total, but it's fine
+            index_t x_start = -y + i_y + 1;          // this could be negative, but it's fine
+            index_t x_end   = min(i_y + x, x_total); // need min in case x is padded

             return i_x < x_start || i_x >= x_end;
         }

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
index c7e2f3ae4b..80a1dbd756 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
@@ -256,11 +256,13 @@ struct BlockFmhaPipelineQRKSVSAsync
                 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
             }

+            buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
+                                  // otherwise will have compute error(maybe compiler bug?)
             // Note: here occ are all cleard, return it
-            // Note: q loaded but no fence, ignore it.
             return o_acc;
         }
+        __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check

     auto k_dram_block_window =
@@ -388,6 +390,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             k_origin.at(number<0>{}),
             number{},
             number{});
+
         if(need_perpixel_check)
         {
             set_tile_if(
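---

Editor's note: for quick reference, below is an illustrative set of invocations exercising the mask syntax documented in the README changes above. The binary name and flags are taken from this example's README/args help; batch, head, and sequence-length values are arbitrary placeholders, so treat this as a sketch rather than an official test set.

```
# causal mask, top-left vs bottom-right (equivalent to -mask=1 / -mask=2)
./bin/tile_example_fmha_fwd -b=2 -h=4 -d=128 -s=512 -s_k=1024 -mask=t
./bin/tile_example_fmha_fwd -b=2 -h=4 -d=128 -s=512 -s_k=1024 -mask=b
# FA-style sliding window attention with left=128, right=30, anchored top-left
./bin/tile_example_fmha_fwd -b=1 -h=4 -d=64 -s=200 -s_k=520 -mask=t:128,30
# xformer-style window_size=16 anchored bottom-right (a negative window_size means causal)
./bin/tile_example_fmha_fwd -b=1 -h=4 -d=64 -s=120 -s_k=256 -mask=xb:16
# group mode (varlen) combined with a bottom-right causal mask
./bin/tile_example_fmha_fwd -mode=1 -b=3 -h=4 -d=128 -s=300 -s_k=512 -mask=2
```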