From 42ebffe822bc7d89eeef0160ac461b36b407a025 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 7 Apr 2024 23:11:29 +0000 Subject: [PATCH] 1).support receipe in generate.py 2).use simplified mask type 3).change left/right to pass into karg --- example/ck_tile/01_fmha/fmha_fwd.cpp | 44 +++- example/ck_tile/01_fmha/fmha_fwd.hpp | 188 +----------------- example/ck_tile/01_fmha/generate.py | 127 +++++++++--- example/ck_tile/01_fmha/mask.hpp | 56 ++++-- example/ck_tile/01_fmha/script/smoke_test.sh | 3 +- include/ck_tile/ops/fmha.hpp | 1 + .../ck_tile/ops/fmha/block/block_masking.hpp | 178 +++++++++++++++-- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 31 ++- .../pipeline/block_fmha_pipeline_enum.hpp | 17 ++ 9 files changed, 380 insertions(+), 265 deletions(-) create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 5a6afe36f6..0eb17f7b1b 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -59,12 +59,13 @@ auto create_args(int argc, char* argv[]) .insert("operm", "1", "permute output") .insert("bias", "0", "add bias or not") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") - .insert("mask", - "0", - "0: no mask, 1: top-left, 2:bottom-right\n" - "'t:l,r', top-left local-attn with left right size\n" - "'b:l,r', bottom-r local-attn with left right size\n" - "'g:y,x', generic attention mask coordinate with y/x size\n") + .insert( + "mask", + "0", + "0: no mask, 1: top-left, 2:bottom-right\n" + "'t:l,r', top-left sliding window attn with left right size\n" + "'b:l,r', bottom-r sliding window attn with left right size\n" + "'g:y,x', generic attention mask coordinate with y/x size (only use this for debug)\n") .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") .insert("lse", "0", "0 not store lse, 1 store lse") .insert("kname", "0", "if set to 1 will print kernel name") @@ -381,8 +382,9 @@ bool run(const ck_tile::ArgParser& arg_parser) batch_stride_bias, batch_stride_lse, batch_stride_o, - mask.y, - mask.x, + mask.left, + mask.right, + static_cast(mask.type), descale_q * descale_k, descale_v}; }(); @@ -498,12 +500,32 @@ bool run(const ck_tile::ArgParser& arg_parser) else if(mask.type == mask_enum::window_generic) { ck_tile::reference_batched_masking( - s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); } else { - ck_tile::reference_batched_masking( - s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); } if(lse) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 9293201cd2..8ff13cfe13 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -80,177 +80,6 @@ struct FmhaMasks using CausalMask = ck_tile::GenericAttentionMask; }; -#if 0 -// internal API, don't use this directly -template -auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* lse_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t batch, - ck_tile::index_t nhead, - ck_tile::index_t nhead_k, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t max_seqlen_q, - float scale, - float descale_qk, - float descale_sv, - bool i_perm, - bool o_perm, - ck_tile::index_t mask_y, - ck_tile::index_t mask_x) -{ - constexpr bool is_v_rowmajor = - std::is_same_v; - - assert(nhead % nhead_k == 0); - /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, - /// seqlen_k] in this example, hence both the 'batch_stride_bias' & 'nhead_stride_bias' - /// are 0. - // setup stride_* arguments - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { - if constexpr(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return i_perm ? seqlen_k : nhead_k * seqlen_k; - }(); - const ck_tile::index_t stride_bias = (i_perm ? seqlen_k : 1 * seqlen_k); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = (i_perm ? seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = (i_perm ? seqlen_k * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_v = [&]() { - if constexpr(is_v_rowmajor) - return i_perm ? seqlen_k * hdim_v : hdim_v; - else - return i_perm ? hdim_v * seqlen_k : seqlen_k; - }(); - const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * seqlen_q * seqlen_k : 0 * seqlen_k); - const ck_tile::index_t nhead_stride_lse = (seqlen_q * 1); - const ck_tile::index_t nhead_stride_o = (o_perm ? seqlen_q * hdim_v : hdim_v); - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = (nhead_k * seqlen_k * hdim_q); - const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * seqlen_k); - const ck_tile::index_t batch_stride_bias = (0 * nhead * seqlen_q * seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * seqlen_q * 1); - const ck_tile::index_t batch_stride_o = (nhead * seqlen_q * hdim_v); - - auto kargs = [&] { - // create group mode kernel arguments - if constexpr(FmhaKernel::kIsGroupMode) - { - return FmhaKernel::MakeKargs(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - lse_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - nhead / nhead_k, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_lse, - nhead_stride_o, - mask_y, - mask_x, - descale_qk, - descale_sv); - } - else - { // create batch mode kernel arguments - return FmhaKernel::MakeKargs(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - lse_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - nhead / nhead_k, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_lse, - batch_stride_o, - mask_y, - mask_x, - descale_qk, - descale_sv); - } - }(); - - dim3 grids = FmhaKernel::GridSize(batch, nhead, max_seqlen_q, hdim_v); - return ck_tile::make_tuple(kargs, grids); -} - -// This is the args from caller to underneath API, different from the kernel -struct fmha_fwd_args -{ - const void* q_ptr; - const void* k_ptr; - const void* v_ptr; - const void* bias_ptr; - void* lse_ptr; - void* o_ptr; - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* seqlen_k_ptr; - ck_tile::index_t batch; - ck_tile::index_t nhead; - ck_tile::index_t nhead_k; - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t hdim_q; - ck_tile::index_t hdim_v; - ck_tile::index_t max_seqlen_q; - float scale; - float descale_qk; - float descale_sv; - bool i_perm; - bool o_perm; - ck_tile::index_t mask_y; - ck_tile::index_t mask_x; -}; -#endif - // runtime args, some will passed to karg, some will used to compute grids/blocks struct fmha_fwd_args { @@ -289,8 +118,9 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; - ck_tile::index_t mask_y; - ck_tile::index_t mask_x; + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; float descale_qk; float descale_sv; }; @@ -327,8 +157,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_bias, args.nhead_stride_lse, args.nhead_stride_o, - args.mask_y, - args.mask_x, + args.window_size_left, + args.window_size_right, + args.mask_type, args.descale_qk, args.descale_sv); } @@ -363,8 +194,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_bias, args.batch_stride_lse, args.batch_stride_o, - args.mask_y, - args.mask_x, + args.window_size_left, + args.window_size_right, + args.mask_type, args.descale_qk, args.descale_sv); } @@ -385,6 +217,7 @@ template ; static constexpr bool kHasBias = kHasBias_; static constexpr bool kStoreLse = kStoreLse_; diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index e415974480..686dd35d19 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -24,6 +24,16 @@ DTYPE_BITS = { "bf8" : 8 } +MASK_IMPL = { + "generic" : "ck_tile::GenericAttentionMask", + "simplified" : "ck_tile::SimplifiedGenericAttentionMask" +} + +MASK_SIMPLIFIED_MAP = { + "s_no" : "ck_tile::SimplifiedGenericAttentionMask", + "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", +} + MASK_MAP = { "no" : "FmhaMasks::NoMask", "causal" : "FmhaMasks::CausalMask", @@ -46,12 +56,17 @@ PIPELINE_MAP = { "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", } +PIPELINE_ENUM_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_fp8" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_FP8", + "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", +} + BOOL_MAP = { "t" : "true", "f" : "false" } -MASKS = ["no", "causal", "generic"] DIRECTIONS = ["fwd"] GEN_DIR = "" # in Cmake, have to generate files in same folder @@ -113,7 +128,8 @@ using fmha_kernel_{F_idx} = fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -149,17 +165,40 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < """ MASK_CHECK_MAP = { "no" : "t.mask_type == mask_enum::no_mask", - "causal" : "t.mask_type == mask_enum::causal_top_left || t.mask_type == mask_enum::causal_bottom_right", + "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", "generic" : "t.mask_type == mask_enum::window_generic", } +MASK_SIMPLIFIED_CHECK_MAP = { + "s_no" : "t.mask_type == mask_enum::no_mask", + "s_mask" : "t.mask_type != mask_enum::no_mask", +} + FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_(s, a); }} """ +def get_mask_map(mask : str): + if mask == "generic": + return MASK_MAP + elif mask == "simplified": + return MASK_SIMPLIFIED_MAP + else: + assert False + return None + +def get_mask_check_map(mask : str): + if mask == "generic": + return MASK_CHECK_MAP + elif mask == "simplified": + return MASK_SIMPLIFIED_CHECK_MAP + else: + assert False + return None + @dataclass class FmhaFwdApiTrait: pipeline_tag : str @@ -193,14 +232,19 @@ class FmhaFwdApiTrait: if self.spad == 't' : return 'true' # always support else : return 'true' elif self.pipeline_tag in ['qr', 'qr_fp8']: - if self.spad == 't' : return f'a.seqlen_q % {self.bm0} != 0' + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False @property def skcheck(self) -> str: - if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k % {self.bn0} == 0' + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False @property def dcheck(self) -> str: @@ -209,7 +253,7 @@ class FmhaFwdApiTrait: if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False elif self.pipeline_tag in ['qr', 'qr_fp8']: - if self.dpad == 't': return f'a.hdim_q % {self.bk0blen} != 0' + if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_q % {self.bk0blen} == 0' else: assert False @@ -220,7 +264,7 @@ class FmhaFwdApiTrait: if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False elif self.pipeline_tag in ['qr', 'qr_fp8']: - if self.dvpad == 't': return f'a.hdim_v % {self.bk0blen} != 0' + if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_v % {self.bk0blen} == 0' else: assert False @@ -251,13 +295,17 @@ class FmhaFwdPipeline: n = f'{self.tag}_v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' if self.F_bias == 't' : n += '_bias' - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' if self.F_lse == 't' : n += '_lse' return n class FmhaFwdApiPool: - def __init__(self): + def __init__(self, mask_impl): self.pool = dict() + self.mask_impl = mask_impl def register_traits(self, trait : FmhaFwdApiTrait) -> None: # TODO: do we need to check duplication? @@ -278,8 +326,9 @@ class FmhaFwdApiPool: inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_mask=MASK_MAP[trait.mask], - F_mask_check=MASK_CHECK_MAP[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, @@ -320,6 +369,7 @@ class FmhaFwdKernel: F_mode : str # value from MODE_MAP F_tile : FmhaFwdTileSize F_pipeline : FmhaFwdPipeline + mask_impl : str @property def template(self) -> str: @@ -347,8 +397,9 @@ class FmhaFwdKernel: F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_bias = BOOL_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_occupancy = self.F_tile.F_occupancy , - F_mask = MASK_MAP[self.F_pipeline.F_mask], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) @@ -403,14 +454,17 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[ else: return None -def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: +def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse in itertools.product(MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]): if hdim == 256: # if True: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, mask)) @@ -423,16 +477,19 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, mask)) pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, mask)) pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, mask)) + if receipt == 1: + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse kernels - for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]): + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]): pipelines.append(FmhaFwdPipeline('qr_fp8', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask)) else: assert False return pipelines gen = list() - api_pool = FmhaFwdApiPool() + api_pool = FmhaFwdApiPool(mask_impl) for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()): d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) @@ -443,7 +500,7 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF tile = d[hdim_str] hdim = int(hdim_str) for pipeline in get_pipelines(dtype, hdim): - k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline) + k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue @@ -458,24 +515,24 @@ def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str]) -> None: +def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: output_dir = Path(output_dir) / GEN_DIR output_dir.mkdir(parents=True, exist_ok=True) - api_pool, kernels = get_blobs(kernel_filter) + api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: write_single_kernel(kernel, output_dir) write_api(api_pool, output_dir) # list all the files that will be generated -def list_blobs(output_file : Optional[str], kernel_filter : Optional[str]) -> None: +def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) with file_path.open('a') as f: - _, kernels = get_blobs(kernel_filter) + _, kernels = get_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") @@ -504,8 +561,26 @@ if __name__ == "__main__": required=False, help="filter out kernels that need to generate, using fnmatch module" ) + + parser.add_argument( + "-m", + "--mask", + default="simplified", + required=False, + help="mask implementation, simplified/generic" + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ + " 1: generate more instance to cover all hdim" + ) + args = parser.parse_args() if args.list_blobs is not None: - list_blobs(args.list_blobs, args.filter) + list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask) else: - write_blobs(args.output_dir, args.filter) + write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask) diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index d652172ede..526ea5dd04 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -9,11 +9,12 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha.hpp" +// keep this in sync with ck_tile::GenericAttentionMaskEnum enum class mask_enum { no_mask = 0, - causal_top_left, - causal_bottom_right, + mask_top_left, + mask_bottom_right, window_generic, }; @@ -21,18 +22,19 @@ struct mask_info { mask_enum type; ck_tile::index_t y, x; + ck_tile::index_t left, right; // FA style SWA left/right void serialize(std::ostream& os) const { if(type == mask_enum::no_mask) os << "n"; - else if(type == mask_enum::causal_top_left) - os << "tl"; - else if(type == mask_enum::causal_bottom_right) - os << "br"; + else if(type == mask_enum::mask_top_left) + os << "tl(" << left << ":" << right << ")"; + else if(type == mask_enum::mask_bottom_right) + os << "br(" << left << ":" << right << ")"; else { - os << "g(" << y << "/" << x << ")"; + os << "g(" << y << ":" << x << ")"; } } static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) @@ -57,22 +59,30 @@ struct mask_info // TODO: some validation if(t == "t") { - auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + tmp.type = mask_enum::mask_top_left; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( v0, v1, y_total, x_total, true); - tmp.y = r.at(ck_tile::number<0>{}); - tmp.x = r.at(ck_tile::number<1>{}); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; } else if(t == "b") { - auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + tmp.type = mask_enum::mask_bottom_right; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( v0, v1, y_total, x_total, false); - tmp.y = r.at(ck_tile::number<0>{}); - tmp.x = r.at(ck_tile::number<1>{}); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; } else if(t == "g") { - tmp.y = v0; - tmp.x = v1; + tmp.y = v0; + tmp.x = v1; + tmp.left = v0; // TODO: don't use this? + tmp.right = v1; } else { @@ -84,15 +94,19 @@ struct mask_info { // should be 0, 1, 2 tmp.type = static_cast(atoi(str.c_str())); - if(tmp.type == mask_enum::causal_top_left) + if(tmp.type == mask_enum::mask_top_left) { - tmp.y = seqlen_q; - tmp.x = 1; + tmp.y = seqlen_q; + tmp.x = 1; + tmp.left = -1; + tmp.right = 0; } - else if(tmp.type == mask_enum::causal_bottom_right) + else if(tmp.type == mask_enum::mask_bottom_right) { - tmp.y = seqlen_q; - tmp.x = seqlen_k - seqlen_q + 1; + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + tmp.left = -1; + tmp.right = 0; } } return tmp; diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test.sh index 012ea42df6..6b7bf8fe41 100644 --- a/example/ck_tile/01_fmha/script/smoke_test.sh +++ b/example/ck_tile/01_fmha/script/smoke_test.sh @@ -23,7 +23,8 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=256 -s_k=512 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=g:128,32 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=120 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS done done diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 1e9acc6d7b..c567e63ddf 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index c256e08e46..39447ca99e 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -7,6 +7,20 @@ namespace ck_tile { +enum struct GenericAttentionMaskEnum +{ + NO_MASK = 0, + + // below enum could be causal, or sliding window + MASK_FROM_TOP_LEFT = 1, + MASK_FROM_BOTTOM_RIGHT = 2, + + // this enum maybe not used by xformer/FA, since it's hard to + // specify left/right window for varlen case. put it here for + // debug purpose + MASK_GENERIC, +}; + // clang-format off /* generic Attention Mask Coordinate use x(horizontal axis), y(vertical axis) to describe mask. @@ -188,6 +202,129 @@ struct GenericAttentionMask index_t y_total, x_total; }; +// clang-format off +namespace impl { + template struct SimplifiedMaskName; + template<> struct SimplifiedMaskName { static constexpr const char * name = "nomask"; }; + template<> struct SimplifiedMaskName { static constexpr const char * name = "mask"; }; +} +// clang-format on + +// this version only have 2 variation: masking and non-masking +// This is more friendly to codegen (e.g. need generate less kernel) +// ... with the trade-off that may have more instruction in causal mode +template +struct SimplifiedGenericAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + + static constexpr const char* name = impl::SimplifiedMaskName::name; + + CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_) + : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_) + { + } + + CK_TILE_HOST_DEVICE + SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + { + } + template + CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord) + : y(mask_coord.at(number<0>{})), + x(mask_coord.at(number<1>{})), + y_total(mask_coord.at(number<2>{})), + x_total(mask_coord.at(number<3>{})) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + index_t tmp = max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck_tile::make_tuple(x_start, x_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + // the only case that need do following compare is under kPadSeqLenK + // ... for non-masking kernel. + return i_x >= x_total; + } + else + { + // no need to do min/max here, since i_x will never be < 0 or >= x_total + index_t x_start = -y + i_y + 1; // this could be negative, but it's fine + index_t x_end = i_y + x; // this could be larger than x_total, but it's fine + + return i_x < x_start || i_x >= x_end; + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX() + // can be used as a fast-path to decide if do per-pixel check or not + template + CK_TILE_HOST_DEVICE constexpr auto + IsEdgeTile(index_t i_y, index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + // the only case that need do following compare is under kPadSeqLenK + // ... for non-masking kernel. + // return (i_x < x_total) && ((i_x + TileWidth) > x_total); + + // TODO: no need to check begin + return (i_x + TileWidth) > x_total; + } + else + { + // check top-right corner > x or left-borrom corner < x + index_t i_x_end = i_x + TileWidth; + index_t i_y_end = i_y + TileHeight; + // index_t x_end = min(i_y + x, x_total); + + bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad + bool bottom_left_edge = i_y_end > (i_x + y); + // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge; + } + } + + private: + index_t y, x; + index_t y_total, x_total; +}; + // TODO: prefer use this function in host code // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask @@ -199,29 +336,32 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, index_t x_total, bool is_top_left = true) { - index_t x = 0, y = 0; + // TODO: below should all use sgpr arithmetic + index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1; + index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1; - if(is_top_left) - { - if(left_size < 0) - left_size = y_total - 1; - if(right_size < 0) - right_size = x_total - 1; + left_size = left_size < 0 ? left_size_tmp : left_size; + right_size = right_size < 0 ? right_size_tmp : right_size; - x = 1 + right_size; - y = left_size + 1; - } - else - { - if(left_size < 0) - left_size = x_total - 1; - if(right_size < 0) - right_size = y_total - 1; + index_t x_tmp = is_top_left ? 0 : x_total - y_total; + index_t y_tmp = is_top_left ? 0 : y_total - x_total; - x = x_total - y_total + 1 + right_size; - y = y_total - x_total + 1 + left_size; - } + index_t x = 1 + right_size + x_tmp; + index_t y = 1 + left_size + y_tmp; return ck_tile::make_tuple(y, x, y_total, x_total); } + +template +CK_TILE_HOST_DEVICE constexpr auto +make_generic_attention_mask_from_lr_window(index_t left_size, + index_t right_size, + index_t y_total, + index_t x_total, + bool is_top_left = true) +{ + auto r = make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, is_top_left); + return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total}; +} } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 98866805a0..a5f7d95d42 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -138,7 +138,9 @@ struct FmhaFwdKernel struct FmhaFwdMaskKargs { - ck_tile::index_t mask_y, mask_x; + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; }; struct FmhaFwdFP8Kargs @@ -217,8 +219,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, - ck_tile::index_t mask_y, - ck_tile::index_t mask_x, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, float descale_qk, float descale_sv) { @@ -262,8 +265,9 @@ struct FmhaFwdKernel } if constexpr(kHasMask) { - kargs.mask_y = mask_y; - kargs.mask_x = mask_x; + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) { @@ -306,8 +310,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t mask_y, - ck_tile::index_t mask_x, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, float descale_qk, float descale_sv) { @@ -349,8 +354,9 @@ struct FmhaFwdKernel } if constexpr(kHasMask) { - kargs.mask_y = mask_y; - kargs.mask_x = mask_x; + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) { @@ -639,7 +645,12 @@ struct FmhaFwdKernel FmhaMask mask = [&]() { if constexpr(kHasMask) - return FmhaMask{kargs.mask_y, kargs.mask_x, kargs.seqlen_q, kargs.seqlen_k}; + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); else return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp new file mode 100644 index 0000000000..ae5a88df21 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockFmhaPipelineEnum +{ + QRKSVS = 0, + QRKSVS_ASYNC, + QRKSVS_FP8, + QSKSVS, +}; + +} // namespace ck_tile