fix some bug in group-mode masking and codegen. update README

carlushuang
2024-04-09 19:01:25 +00:00
parent 42ebffe822
commit 89a75a97fa
7 changed files with 148 additions and 76 deletions

View File

@@ -32,7 +32,8 @@ args:
-h num of head, for q (default:8)
-h_k num of head, for k/v, 0 means equal to h (default:0)
if not equal to h, then this is GQA/MQA case
-s seqlen_q (default:3328)
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
-s_k seqlen_k, 0 means equal to s (default:0)
-d head dim for q, k (default:128)
-d_v head dim for v, 0 means equal to d (default:0)
@@ -45,14 +46,19 @@ args:
-operm permute output (default:1)
-bias add bias or not (default:0)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-mask 0: no mask, 1: top-left, 2:bottom-right (default:0)
't:l,r', top-left local-attn with left right size
'b:l,r', bottom-r local-attn with left right size
'g:y,x', generic attention mask coordinate with y/x size
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
't', top-left causal mask, 'b', bottom-r causal mask
't:l,r', top-left sliding window attn(swa) with FA style left right size
'b:l,r', bottom-r sliding window attn(swa) with FA style left right size
'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
-lse 0 not store lse, 1 store lse (default:0)
-kname if set to 1 will print kernel name (default:0)
-init init method. 0:random int, 1:random float, 2:trig float (default:1)
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
```
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run an fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, in fp16.
@@ -63,7 +69,7 @@ Currently we are still in rapid development stage, so more features/optimization
Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` are better optimized. hdim should be a multiple of 8, while seqlen can be arbitrary. Arbitrary hdim can be supported through the padding kernel of the `qr` pipeline (we don't generate this in generate.py by default)
### group/batch mode
Currently we support both batch and group mode, by setting `-mode` = `0` or `1`, where in group mode each batch can have a different seqlen
Currently we support both `batch mode` and `group mode` (or `varlen`, in FA's terms), by setting `-mode` = `0` or `1`. In `group mode`, different kinds of attention masks are also supported (see below).
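A minimal sketch of the packing idea behind `group mode`/varlen follows; the names below (e.g. `seqstart_q`) are illustrative, not necessarily this example's exact API. Each batch keeps its own seqlen, and `-s` only fixes the average, so `total_seqlen_q = seqlen_q * batch`:
```cpp
// Hypothetical sketch of varlen packing; offset names are illustrative.
#include <cstdint>
#include <numeric>
#include <vector>

int main()
{
    // per-batch seqlen_q may vary; -s only fixes the average
    std::vector<int32_t> seqlens_q{100, 56, 300};

    // cumulative start offsets: batch i occupies rows
    // [seqstart_q[i], seqstart_q[i+1]) of the packed q tensor
    std::vector<int32_t> seqstart_q(seqlens_q.size() + 1, 0);
    std::partial_sum(seqlens_q.begin(), seqlens_q.end(), seqstart_q.begin() + 1);
    // seqstart_q = {0, 100, 156, 456}; total_seqlen_q = 456
    return 0;
}
```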
### MQA/GQA
By setting `-h` (nhead for q) and `-h_k` (nhead for k/v) to different numbers, you can achieve MQA/GQA. Note that `h % h_k == 0` must hold when you set different numbers.
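For illustration, the usual GQA/MQA head mapping looks like the sketch below, where each group of `h/h_k` query heads shares one k/v head. This is the standard scheme; the kernel's actual indexing may differ:
```cpp
// Standard GQA/MQA head mapping sketch (assumes h % h_k == 0).
int map_q_head_to_kv_head(int q_head, int h, int h_k)
{
    const int group_size = h / h_k; // query heads per k/v head
    return q_head / group_size;     // h_k == 1 degenerates to MQA
}
```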
@@ -80,11 +86,22 @@ For training kernels, "log sum exp" need to store out in forward and used in bac
### vlayout
We support the v matrix in both row-major (`seqlen*hdim`) and col-major (`hdim*seqlen`) layouts. Since the accumulation (reduce) dimension for V is along `seqlen`, and AMD's current mfma layout expects each thread to hold contiguous registers along the reduce dimension, the col-major V layout is easier to support. However, col-major is not necessarily faster than row-major; many factors affect overall performance. We provide `-vlayout=r/c` to switch/test between the two layouts.
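As a sketch (assuming contiguous, unpadded strides for illustration), the two layouts address a V element like this:
```cpp
// Sketch of how one V element is addressed under the two layouts.
float v_at(const float* v, int s, int d, int seqlen, int hdim, bool row_major)
{
    // row-major: v is seqlen*hdim, seqlen is the slow (outer) dimension
    // col-major: v is hdim*seqlen, the reduce dim (seqlen) is contiguous
    return row_major ? v[s * hdim + d] : v[d * seqlen + s];
}
```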
### generic attention mask coordinate
We unify the mask expression into a generic attention mask coordinate, providing a uniform approach to describe causal top-left, causal bottom-right, and local attention.
### attention mask
We support `causal mask` and `sliding window attention (swa)` mask in both batch and group mode, either from top-left or bottom-right.
Underneath, we unify the mask expression into a `generic attention mask coordinate`, providing a uniform approach for each batch to locate the pixels that need to be masked out.
![](misc/gamc.png)
(more description to be added)
Since the FA/xformer style with window_size_left/right is more popular, we accept window_size as a parameter and convert it internally to our generic coordinate (which can express more cases). The table below shows some examples of how to achieve different kinds of masks through the cmdline.
| mask case| cmdline | FA style | xformer style |
|----------|:-------------:|:-------------:|:-------------:|
| no mask | `-mask=0`(default) | | |
| causal mask from top-left | `-mask=1` or `-mask=t` | `-mask=t:-1,0` | `-mask=xt:-1` |
| causal mask from bottom-right | `-mask=2` or `-mask=b` | `-mask=b:-1,0` | `-mask=xb:-1` |
| swa from top-left | | `-mask=t:3,5` | `-mask=xt:4` |
| swa from bottom-right | | `-mask=b:10,11` | `-mask=xb:16` |
Note that FA uses bottom-right by default to express the swa case; here we require you to explicitly specify top-left/bottom-right.
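For reference, the xformer-style `window_size` is split into FA-style left/right sizes exactly as in the mask parsing hunk of this commit; a standalone restatement:
```cpp
// How 'xt:/xb:window_size' maps to FA-style left/right sizes,
// following the parsing code in this commit.
#include <utility>

std::pair<int, int> xformer_window_to_lr(int window_size)
{
    int left = -1, right = 0;   // negative window_size -> causal
    if(window_size > 0)         // positive -> sliding window
    {
        left  = window_size / 2;
        right = window_size - 1 - left;
    }
    return {left, right};       // e.g. xt:4 -> left=2, right=1
}
```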
### dropout
TBD

View File

@@ -44,7 +44,10 @@ auto create_args(int argc, char* argv[])
"0",
"num of head, for k/v, 0 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("s", "3328", "seqlen_q")
.insert("s",
"3328",
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
.insert("s_k", "0", "seqlen_k, 0 means equal to s")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "0", "head dim for v, 0 means equal to d")
@@ -59,21 +62,26 @@ auto create_args(int argc, char* argv[])
.insert("operm", "1", "permute output")
.insert("bias", "0", "add bias or not")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert(
"mask",
"0",
"0: no mask, 1: top-left, 2:bottom-right\n"
"'t:l,r', top-left sliding window attn with left right size\n"
"'b:l,r', bottom-r sliding window attn with left right size\n"
"'g:y,x', generic attention mask coordinate with y/x size (only use this for debug)\n")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
"'xt:window_size', xformer style masking from top-left, window_size negative is "
"causal, possitive is swa\n"
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
"causal, possitive is swa\n"
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
"now)\n")
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
.insert("lse", "0", "0 not store lse, 1 store lse")
.insert("kname", "0", "if set to 1 will print kernel name")
.insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 to use "
"non-deterministic random number as seed")
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");

View File

@@ -228,6 +228,7 @@ class FmhaFwdApiTrait:
@property
def scheck(self) -> str:
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generates spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
@@ -238,6 +239,7 @@ class FmhaFwdApiTrait:
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generates spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k % {self.bn0} == 0'
@@ -500,6 +502,10 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict whether the seqlen of the current batch needs padding or not
continue
k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
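The padding predicates above boil down to the following idea, restated here as a hedged C++ sketch (the generator emits it as string snippets like `a.seqlen_k % bn0 != 0`): batch mode can decide at dispatch time whether a padded kernel variant is needed, while group mode cannot, so padding is forced on.
```cpp
// Sketch of the dispatch idea behind scheck/skcheck.
bool needs_seqlen_k_pad(bool group_mode, int seqlen_k, int bn0 /*tile size*/)
{
    if(group_mode)
        return true;               // can't predict per-batch seqlen_k here
    return seqlen_k % bn0 != 0;    // batch mode: pad only when not tile-aligned
}
```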

View File

@@ -29,9 +29,9 @@ struct mask_info
if(type == mask_enum::no_mask)
os << "n";
else if(type == mask_enum::mask_top_left)
os << "tl(" << left << ":" << right << ")";
os << "t(" << left << ":" << right << ")";
else if(type == mask_enum::mask_bottom_right)
os << "br(" << left << ":" << right << ")";
os << "b(" << left << ":" << right << ")";
else
{
os << "g(" << y << ":" << x << ")";
@@ -47,66 +47,103 @@ struct mask_info
{
std::string t = str.substr(0, found_0);
std::string v = str.substr(found_0 + 1);
auto found_1 = v.find(",");
if(found_1 == std::string::npos)
if(t == "xt" || t == "xb")
{
printf("not supported value %s, %s\n", v.c_str(), str.c_str());
assert(0);
}
tmp.type = mask_enum::window_generic;
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
// TODO: some validation
if(t == "t")
{
tmp.type = mask_enum::mask_top_left;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, true);
// xformer style sliding window attn, from top-left ('xt') or bottom-right ('xb')
ck_tile::index_t window_size = atoi(v.c_str());
ck_tile::index_t left_size = -1;
ck_tile::index_t right_size = 0;
if(window_size > 0)
{
left_size = window_size / 2;
right_size = window_size - 1 - left_size;
}
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, t == "xt");
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "b")
{
tmp.type = mask_enum::mask_bottom_right;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, false);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "g")
{
tmp.y = v0;
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
tmp.left = left_size;
tmp.right = right_size;
}
else
{
printf("not supported type %s, %s\n", t.c_str(), str.c_str());
assert(0);
auto found_1 = v.find(",");
if(found_1 == std::string::npos)
{
printf("not supported value %s, %s\n", v.c_str(), str.c_str());
assert(0);
}
tmp.type = mask_enum::window_generic;
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
// TODO: some validation
if(t == "t")
{
tmp.type = mask_enum::mask_top_left;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, true);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "b")
{
tmp.type = mask_enum::mask_bottom_right;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, false);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "g")
{
tmp.y = v0;
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
}
else
{
printf("not supported type %s, %s\n", t.c_str(), str.c_str());
assert(0);
}
}
}
else
{
// should be 0, 1, 2
tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
if(tmp.type == mask_enum::mask_top_left)
{
auto set_causal_top_left = [&]() {
tmp.type = mask_enum::mask_top_left;
tmp.y = seqlen_q;
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
}
else if(tmp.type == mask_enum::mask_bottom_right)
{
};
auto set_causal_bottom_right = [&]() {
tmp.type = mask_enum::mask_bottom_right;
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
};
if(str == "t")
set_causal_top_left();
else if(str == "b")
set_causal_bottom_right();
else
{
tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
if(tmp.type == mask_enum::mask_top_left)
{
set_causal_top_left();
}
else if(tmp.type == mask_enum::mask_bottom_right)
{
set_causal_bottom_right();
}
}
}
return tmp;
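The two causal presets (`set_causal_top_left`/`set_causal_bottom_right`) can be restated as a small worked example:
```cpp
// Causal presets from the lambdas above, as a worked example:
// with seqlen_q=4, seqlen_k=6,
//   top-left     -> y=4, x=1
//   bottom-right -> y=4, x=6-4+1=3 (diagonal ends at the bottom-right corner)
struct generic_coord { int y, x; };

generic_coord causal_coord(int seqlen_q, int seqlen_k, bool top_left)
{
    return top_left ? generic_coord{seqlen_q, 1}
                    : generic_coord{seqlen_q, seqlen_k - seqlen_q + 1};
}
```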

View File

@@ -8,23 +8,25 @@ export CK_WARMUP=0
export CK_REPEAT=1
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
mode=0
# mode=0
# export HIP_VISIBLE_DEVICES=4
for prec in "fp16" "bf16" ; do
# for mode in 1 0 ; do
for mode in 1 0 ; do
for perm in 0 1 ; do
for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
for bias in 0 1 ; do
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=120 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done
done
@@ -32,4 +34,4 @@ done
done
done
done
#done
done