diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 4ff7ede765..3b0f4aaf4b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -51,6 +51,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + {F_logits}, {F_bias}, false, {F_lse}, @@ -88,7 +89,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -123,9 +124,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < }} """ -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_(s, a); }} """ @@ -144,6 +145,7 @@ class FmhaFwdApiTrait: bk1 : int # tile size along kv gemm unroll bk0max : int vlayout : str + logits : str mask : str bias : str # lse : str # @@ -157,7 +159,7 @@ class FmhaFwdApiTrait: @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' @property def scheck(self) -> str: @@ -214,6 +216,7 @@ class FmhaFwdPipeline: F_skpad : str # F_dpad : str # F_dvpad : str # + F_logits : str # t/f F_bias : str # true/false F_lse : str # F_dropout : str # @@ -235,6 +238,9 @@ class FmhaFwdPipeline: if pn != '' : n += f'_{pn}' else: n += '_npad' + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + if self.F_bias != 'no' : n += f'_{self.F_bias}' else: n += '_nbias' @@ -280,7 +286,7 @@ class FmhaFwdApiPool: for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, @@ -365,6 +371,7 @@ class FmhaFwdKernel: F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], @@ -399,6 +406,7 @@ class FmhaFwdKernel: bk0max=self.F_tile.F_bk0max, vlayout=self.F_pipeline.F_vlayout, mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, dropout=self.F_pipeline.F_dropout, @@ -413,12 +421,12 @@ class FmhaFwdKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { @@ -440,33 +448,33 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): if hdim == 256: # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask)) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index b1f9e30178..26753f5ea4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -63,6 +63,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + {F_logits}, {F_bias}, /*kHasBiasGrad=*/false, {F_lse}, @@ -111,7 +112,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) }} using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -265,9 +266,9 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const }} """ -FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; @@ -308,6 +309,7 @@ class FmhaFwdSplitKVApiTrait: bk0max : int vlayout : str mask : str + logits : str bias : str # lse : str # squant : str # @@ -320,7 +322,7 @@ class FmhaFwdSplitKVApiTrait: @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ f'{self.dvpad}-{self.pagedkv}' @property @@ -378,6 +380,7 @@ class FmhaFwdSplitKVPipeline: F_skpad : str # F_dpad : str # F_dvpad : str # + F_logits : str # t/f F_bias : str # true/false F_lse : str # F_squant : str # @@ -399,6 +402,9 @@ class FmhaFwdSplitKVPipeline: if pn != '' : n += f'_{pn}' else: n += '_npad' + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + if self.F_bias != 'no' : n += f'_{self.F_bias}' else: n += '_nbias' @@ -468,7 +474,7 @@ class FmhaFwdSplitKVApiPool: for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, @@ -534,6 +540,7 @@ class FmhaFwdSplitKVKernel: F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_squant = BOOL_MAP[self.F_pipeline.F_squant], @@ -567,6 +574,7 @@ class FmhaFwdSplitKVKernel: bk1=self.F_tile.F_bk1, bk0max=self.F_tile.F_bk0max, vlayout=self.F_pipeline.F_vlayout, + logits=self.F_pipeline.F_logits, mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, @@ -617,11 +625,11 @@ class FmhaFwdSplitKVCombineKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), - '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + ### '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + ### '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + ### '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { @@ -664,26 +672,26 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): + for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): # TODO: use async pipeline when compiler is more stable if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: # if True: - pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) else: - pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) if receipt == 1: - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim - pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask)) + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask)) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index b3855e59df..fb42ed52a6 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -72,6 +73,7 @@ auto create_args(int argc, char* argv[]) "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" "note when squant=1, this value will be modified by range_q/k") + .insert("logits_soft_cap", "0", "attention logits soft capping value.") .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") @@ -416,6 +418,8 @@ bool run(const ck_tile::ArgParser& arg_parser) if(scale_s == .0f) scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? + const float logits_soft_cap = arg_parser.get_float("logits_soft_cap"); + std::string squant_str = arg_parser.get_str("squant"); bool squant = [&]() { if(squant_str == "auto") @@ -850,6 +854,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else // fmha_fwd_traits or fmha_splitkv_traits { traits.is_group_mode = (mode == mode_enum::group); + traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; traits.has_lse = lse; @@ -1007,6 +1012,8 @@ bool run(const ck_tile::ArgParser& arg_parser) args.scale_p = scale_p; args.scale_o = scale_o; + args.logits_soft_cap = logits_soft_cap; + args.stride_bias = (bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias); args.stride_o = stride_o; @@ -1375,6 +1382,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::identity{}, ck_tile::scales(scale_s)); + if(0.f < logits_soft_cap) + { + ck_tile::reference_unary_elementwise( + s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) { + return ck_tile::type_convert( + logits_soft_cap * + std::tanhf(ck_tile::type_convert(logits / logits_soft_cap))); + }); + } + if(bias.type == bias_enum::elementwise_bias) { // elementwise bias diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 765c221a7b..70a1c40903 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -143,6 +143,8 @@ struct fmha_fwd_args float scale_p; float scale_o; + float logits_soft_cap; + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; @@ -232,6 +234,8 @@ struct fmha_fwd_splitkv_args float scale_p; float scale_o; + float logits_soft_cap; + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; @@ -333,6 +337,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.scale_s, args.scale_p, args.scale_o, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -371,6 +376,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.scale_s, args.scale_p, args.scale_o, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -443,6 +449,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.is_gappy, args.scale_s, args.scale_p, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -485,6 +492,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.cache_batch_idx, args.scale_s, args.scale_p, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -630,6 +638,7 @@ template ; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; @@ -677,6 +687,7 @@ template ; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; @@ -784,6 +796,7 @@ struct fmha_fwd_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; + bool has_logits_soft_cap; mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; @@ -800,6 +813,7 @@ struct fmha_fwd_splitkv_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; + bool has_logits_soft_cap; mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index c671463252..0241ed01af 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -47,6 +47,7 @@ struct FmhaFwdKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; @@ -94,7 +95,7 @@ struct FmhaFwdKernel "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + - (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); #undef _SS_ #undef _TS_ @@ -139,6 +140,28 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o; }; + struct FmhaFwdLogitsSoftCapKargs + { + FmhaFwdLogitsSoftCapKargs() = default; + + void init_logits_soft_cap(float logits_soft_cap_) + { + if(0 < logits_soft_cap_) + { + logits_soft_cap = logits_soft_cap_; + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap = 0.f; + logits_soft_cap_rcp = 0.f; + } + } + + float logits_soft_cap; + float logits_soft_cap_rcp; + }; + struct FmhaFwdCommonBiasKargs { const void* bias_ptr = nullptr; @@ -242,7 +265,8 @@ struct FmhaFwdKernel std::conditional_t>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -260,7 +284,8 @@ struct FmhaFwdKernel std::conditional_t>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -287,6 +312,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -343,6 +369,7 @@ struct FmhaFwdKernel {}, // placeholder for lse {}, // placeholder for fp8_static_quant args {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap batch_stride_q, batch_stride_k, batch_stride_v, @@ -398,6 +425,10 @@ struct FmhaFwdKernel kargs.batch_stride_randval = batch_stride_randval; kargs.is_store_randval = s_randval; } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } return kargs; } @@ -421,6 +452,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -465,6 +497,7 @@ struct FmhaFwdKernel scale_s, scale_p, scale_o, + logits_soft_cap, stride_q, stride_k, stride_v, @@ -512,6 +545,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -556,6 +590,7 @@ struct FmhaFwdKernel scale_s, scale_p, scale_o, + logits_soft_cap, stride_q, stride_k, stride_v, @@ -603,6 +638,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -652,6 +688,7 @@ struct FmhaFwdKernel {}, // placeholder for lse {}, // placeholder for fp8_static_quant args {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -703,6 +740,10 @@ struct FmhaFwdKernel kargs.nhead_stride_randval = nhead_stride_randval; kargs.is_store_randval = s_randval; } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } return kargs; } @@ -727,6 +768,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -765,6 +807,7 @@ struct FmhaFwdKernel scale_s, scale_p, scale_o, + logits_soft_cap, stride_q, stride_k, stride_v, @@ -806,6 +849,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -844,6 +888,7 @@ struct FmhaFwdKernel scale_s, scale_p, scale_o, + logits_soft_cap, stride_q, stride_k, stride_v, @@ -1328,6 +1373,7 @@ struct FmhaFwdKernel mask, position_encoding, kargs.scale_s, + kargs, smem_ptr, dropout); } @@ -1342,6 +1388,7 @@ struct FmhaFwdKernel mask, position_encoding, kargs.scale_s, + kargs, smem_ptr, dropout); } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 14d0596287..efa9061ad7 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -43,6 +43,7 @@ struct FmhaFwdSplitKVKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; @@ -95,7 +96,7 @@ struct FmhaFwdSplitKVKernel "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + - (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" ); #undef _SS_ #undef _TS_ @@ -149,6 +150,28 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t split_stride_o_acc; }; + struct LogitsSoftCapKargs + { + LogitsSoftCapKargs() = default; + + void init_logits_soft_cap(float logits_soft_cap_) + { + if(0 < logits_soft_cap_) + { + logits_soft_cap = logits_soft_cap_; + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap = 0.f; + logits_soft_cap_rcp = 0.f; + } + } + + float logits_soft_cap; + float logits_soft_cap_rcp; + }; + struct CommonBiasKargs { const void* bias_ptr = nullptr; @@ -206,7 +229,8 @@ struct FmhaFwdSplitKVKernel EmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t + std::conditional_t, + std::conditional_t> { const int32_t* seqlen_k_ptr; @@ -228,7 +252,8 @@ struct FmhaFwdSplitKVKernel EmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -267,6 +292,7 @@ struct FmhaFwdSplitKVKernel const void* cache_batch_idx, float scale_s, float scale_p, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -323,6 +349,7 @@ struct FmhaFwdSplitKVKernel {}, // placeholder for mask {}, // placeholder for fp8_static_quant args {}, // placeholder for paged-block table or cache_batch_idx + {}, // placeholder for logits_soft_cap reinterpret_cast(seqlen_k_ptr), batch_stride_q, batch_stride_k, @@ -362,6 +389,10 @@ struct FmhaFwdSplitKVKernel { kargs.cache_batch_idx = reinterpret_cast(cache_batch_idx); } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } return kargs; } @@ -391,6 +422,7 @@ struct FmhaFwdSplitKVKernel bool is_gappy, float scale_s, float scale_p, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -443,6 +475,7 @@ struct FmhaFwdSplitKVKernel {}, // placeholder for mask {}, // placeholder for fp8_static_quant args {}, // placeholder for paged-block table + {}, // placeholder for logits_soft_cap reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr), @@ -477,6 +510,10 @@ struct FmhaFwdSplitKVKernel kargs.page_block_size = page_block_size; kargs.is_gappy = is_gappy; } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } return kargs; } @@ -990,6 +1027,7 @@ struct FmhaFwdSplitKVKernel mask, position_encoding, kargs.scale_s, + kargs, kv_l2p_offset, smem_ptr); } @@ -1007,6 +1045,7 @@ struct FmhaFwdSplitKVKernel mask, position_encoding, kargs.scale_s, + kargs, kv_l2p_offset, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index ce80dba5eb..ce9d388693 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -45,15 +45,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kIsPagedKV = Problem::kIsPagedKV; - static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsPagedKV = Problem::kIsPagedKV; + static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + + static_assert(CK_TILE_FMHA_FWD_FAST_EXP2 || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -127,7 +131,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename LogitsSoftCapParams> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -149,6 +154,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const LogitsSoftCapParams& logits_soft_cap_params, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { @@ -167,6 +173,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + (void)logits_soft_cap_params; + // K tile in LDS KDataType* k_lds_ptr = static_cast(static_cast( static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); @@ -662,7 +670,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS typename VPageBlockNavigator, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename LogitsSoftCapParams> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile @@ -676,6 +685,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const LogitsSoftCapParams& logits_soft_cap_params, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { @@ -699,6 +709,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS mask, position_encoding, scale_s, + logits_soft_cap_params, kv_l2p_offset, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 9a5208c025..069e6243b8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -50,6 +50,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kHasDropout = Traits::kHasDropout; @@ -98,6 +99,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 8a4a925b81..31788f5853 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -47,14 +47,18 @@ struct BlockFmhaPipelineQRKSVS static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert(CK_TILE_FMHA_FWD_FAST_EXP2 || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -128,7 +132,8 @@ struct BlockFmhaPipelineQRKSVS typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename LogitsSoftCapParams> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -147,6 +152,7 @@ struct BlockFmhaPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const LogitsSoftCapParams& logits_soft_cap_params, void* smem_ptr, DropoutType& dropout) const { @@ -165,6 +171,8 @@ struct BlockFmhaPipelineQRKSVS kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + (void)logits_soft_cap_params; + // K tile in LDS KDataType* k_lds_ptr = static_cast(static_cast( static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); @@ -614,7 +622,8 @@ struct BlockFmhaPipelineQRKSVS typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename LogitsSoftCapParams> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -625,6 +634,7 @@ struct BlockFmhaPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const LogitsSoftCapParams& logits_soft_cap_params, void* smem_ptr, DropoutType& dropout) const { @@ -645,6 +655,7 @@ struct BlockFmhaPipelineQRKSVS mask, position_encoding, scale_s, + logits_soft_cap_params, smem_ptr, dropout); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index cdced9ba2e..54e53d230c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -53,13 +53,17 @@ struct BlockFmhaPipelineQRKSVSAsync // only need special care about seq_k padding (oob need set -INF of p instead of zero) static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && Problem::kPadHeadDimV == true); - static constexpr bool kPadSeqLenQ = true; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) - static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert(CK_TILE_FMHA_FWD_FAST_EXP2 || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -153,7 +157,8 @@ struct BlockFmhaPipelineQRKSVSAsync typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename LogitsSoftCapParams> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -172,6 +177,7 @@ struct BlockFmhaPipelineQRKSVSAsync FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const LogitsSoftCapParams& logits_soft_cap_params, void* smem_ptr, DropoutType& dropout) const { @@ -192,6 +198,7 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + (void)logits_soft_cap_params; const float logits_cap = 30.0f; const float logits_cap_rev = 0.0333333f; // const float logits_cap_scale = scale_s * rcp(logits_cap * log2e_v); @@ -748,7 +755,8 @@ struct BlockFmhaPipelineQRKSVSAsync typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename LogitsSoftCapParams> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -759,6 +767,7 @@ struct BlockFmhaPipelineQRKSVSAsync FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const LogitsSoftCapParams& logits_soft_cap_params, void* smem_ptr, DropoutType& dropout) const { @@ -779,6 +788,7 @@ struct BlockFmhaPipelineQRKSVSAsync mask, position_encoding, scale_s, + logits_soft_cap_params, smem_ptr, dropout); } diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 8d2d848558..4530b58d85 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -13,6 +13,7 @@ template 1 or fwd training is running */ @@ -51,6 +54,7 @@ struct TileFmhaFwdSplitKVTraits static constexpr bool kPadSeqLenK = kPadSeqLenK_; static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kStoreLSE = kStoreLSE_;