diff --git a/CHANGELOG.md b/CHANGELOG.md index 15fdb09f49..997fb8bb8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". +* Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines. ### Changed diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 6e7d69281d..9c81207361 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -65,7 +65,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS # there is no corresponding instance for parameters). if(BUILD_TESTING) # Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill - list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv) + list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*) endif() # generate a list of kernels, but not actually emit files at config sta diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index edc0e049c5..4d6900a802 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -76,7 +76,8 @@ using fmha_traits = ck_tile::TileFmhaTraits<{F_spad}, {F_dropout}, {F_qscale}, {F_occupancy}, - {F_skip}>; + {F_skip}, + {F_sink}>; using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -113,7 +114,7 @@ using fmha_kernel = {F_kernel}; using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; template<> float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) @@ -229,9 +230,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hd }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; return fmha_fwd_(s, a); }} """ @@ -278,13 +279,14 @@ class FmhaFwdApiTrait: dvpad: str skip: str tr_load: str + sink: str constraint: CppConstraint @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}" ) @property @@ -384,6 +386,7 @@ class FmhaFwdPipeline: F_mask: str # value from MASK_MAP F_skip: str # true/false F_trload: str # true/false + F_sink: str # true/false F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -454,6 +457,10 @@ class FmhaFwdPipeline: n += "_trload" else: n += "_ntrload" + if self.F_sink == "t": + n += "_sink" + else: + n += "_nsink" return n @@ -543,6 +550,7 @@ class FmhaFwdApiPool: F_trload=BOOL_MAP[trait.tr_load], F_qscale_check=QSCALE_CHECK_MAP[trait.qscale], F_qscale=QSCALE_MAP[trait.qscale], + F_sink=BOOL_MAP[trait.sink], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, @@ -683,6 +691,7 @@ class FmhaFwdKernel: F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag), F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag), + F_sink=BOOL_MAP[self.F_pipeline.F_sink], ) @property @@ -725,6 +734,7 @@ class FmhaFwdKernel: dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, tr_load=self.F_pipeline.F_trload, + sink=self.F_pipeline.F_sink, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, ) @@ -957,52 +967,55 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): pipelines = [] if dtype in cls._DT_FP32: qscale = "no" - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], + ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip elif dtype in cls._DT_FP16_BF16: qscale = "no" - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], + ["t", "f"], ): if hdim == 256 and hdim_v == 256: - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip else: - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels - for logits, qscale, mask, bias in itertools.product( + for logits, qscale, mask, bias, sink in itertools.product( ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"], + ["f", "t"], ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip elif dtype in ["fp8", "fp8fp16", "bf8"]: # TODO pass @@ -1033,13 +1046,14 @@ class KernelComponentFactoryGfx950( ) if dtype in cls._DT_FP16_BF16: qscale = "no" - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], + ["t", "f"], ): if ( (hdim, hdim_v) in [(64, 64), (128, 128)] @@ -1048,15 +1062,15 @@ class KernelComponentFactoryGfx950( and dropout == "f" and skip == "f" ): - pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip # qr_async_trload_v3 only supports hdim=hdim_v=128 for now if (hdim, hdim_v) == (128, 128): # qr_async_trload_v3 only supports (generic) causal mask for mask in ["no", "causal"]: pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", - F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip + F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip return pipelines @@ -1105,23 +1119,24 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): pipelines = [] if dtype in cls._DT_FP16_BF16: qscale = "no" - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], + ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip return pipelines diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 342a71e0b0..9105900fc7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -73,7 +73,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, {F_pagedkv}, kHasUnevenSplits, kMergeNumHeadGroupsSeqLenQ, - {F_occupancy}>; + {F_occupancy}, + {F_sink}>; using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< typename FmhaFwdTypeConfig::QDataType, @@ -117,7 +118,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) }} // anonymous namespace using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #pragma clang diagnostic push @@ -279,8 +280,8 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const """ FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; @@ -332,6 +333,7 @@ class FmhaFwdSplitKVApiTrait: dpad: str dvpad: str pagedkv: str + sink: str # sink or not bn1comb: int # tile size along v head_dim of combine kernel @property @@ -339,7 +341,7 @@ class FmhaFwdSplitKVApiTrait: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-" - + f"{self.dvpad}-{self.pagedkv}" + + f"{self.dvpad}-{self.pagedkv}-{self.sink}" ) @property @@ -425,6 +427,7 @@ class FmhaFwdSplitKVPipeline: F_lse: str # F_squant: str # F_pagedkv: str # t/f + F_sink: str # t/f F_mask: str # value from MASK_MAP @property @@ -485,6 +488,10 @@ class FmhaFwdSplitKVPipeline: n += "_pagedkv" else: n += "_npagedkv" + if self.F_sink == "t": + n += "_sink" + else: + n += "_nsink" return n @@ -567,6 +574,7 @@ class FmhaFwdSplitKVApiPool: F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], + F_sink=BOOL_MAP[trait.sink], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, @@ -667,6 +675,7 @@ class FmhaFwdSplitKVKernel: F_squant=BOOL_MAP[self.F_pipeline.F_squant], F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], F_occupancy=self.F_tile.F_occupancy, + F_sink=BOOL_MAP[self.F_pipeline.F_sink], F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode=MODE_MAP[self.F_mode], @@ -740,19 +749,23 @@ class KernelComponentFactoryBase: squant = "t" if dtype == "fp8" else "f" pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, bias, pagedkv in itertools.product( - ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"] + for logits, mask, bias, pagedkv, sink in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t", "f"], + ["t", "f"], ): - pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip elif dtype in ["fp8", "bf8"]: for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None @@ -908,6 +921,7 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" + cond &= pipeline.F_sink == "f" if not cond: continue # PyTorch integration @@ -917,6 +931,7 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_bias in ["no", "bias"] cond &= pipeline.F_squant == "f" cond &= mode == "batch" + cond &= pipeline.F_sink == "f" if not cond: continue # Aiter(mha_varlen_fwd) integration @@ -1075,6 +1090,7 @@ def write_blobs( lse=kernel.F_pipeline.F_lse, squant=kernel.F_pipeline.F_squant, pagedkv=kernel.F_pipeline.F_pagedkv, + sink=kernel.F_pipeline.F_sink, spad=kernel.F_pipeline.F_spad, skpad=kernel.F_pipeline.F_skpad, dpad=kernel.F_pipeline.F_dpad, diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index e6eb893a2f..cdb43c3480 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -65,7 +65,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad}, {F_pagedkv}, //pagedkv {F_squant}, {F_occupancy}, - {F_skip}>; + {F_skip}, + {F_sink}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -100,7 +101,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdPagedKVKernel; using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>; template<> float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) @@ -129,9 +130,9 @@ float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, con }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>; return fmha_fwd_pagedkv_(s, a); }} """ @@ -163,12 +164,13 @@ class FmhaFwdApiTrait: dpad: str dvpad: str skip: str + sink: str @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}" ) @property @@ -256,6 +258,7 @@ class FmhaFwdPipeline: F_squant: str # F_mask: str # value from MASK_MAP F_skip: str # true/false + F_sink: str # true/false @property def name(self) -> str: @@ -320,6 +323,10 @@ class FmhaFwdPipeline: n += "_pagedkv" else: n += "_npagedkv" + if self.F_sink == "t": + n += "_sink" + else: + n += "_nsink" return n @@ -363,6 +370,7 @@ class FmhaFwdApiPool: F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip], + F_sink=BOOL_MAP[trait.sink], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, @@ -480,6 +488,7 @@ class FmhaFwdKernel: F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], F_squant=BOOL_MAP[self.F_pipeline.F_squant], F_skip=BOOL_MAP[self.F_pipeline.F_skip], + F_sink=BOOL_MAP[self.F_pipeline.F_sink], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -526,6 +535,7 @@ class FmhaFwdKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, + sink=self.F_pipeline.F_sink, ) @@ -539,22 +549,23 @@ class KernelComponentFactoryBase: squant = "t" if dtype == "fp8" else "f" pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, bias, pagedkv, skip in itertools.product( + for logits, mask, bias, pagedkv, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"], + ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip elif dtype in ["fp8", "bf8"]: # no need lse/dropout kernels for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: pass # TODO else: @@ -678,6 +689,7 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" + cond &= pipeline.F_sink == "f" if not cond: continue # PyTorch integration @@ -687,6 +699,7 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "bias"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" + cond &= pipeline.F_sink == "f" if not cond: continue # Aiter(mha_fwd) integration diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 002d0a1035..60ba334fc0 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -266,6 +266,7 @@ struct fmha_fwd_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; ck_tile::index_t mask_type; ck_tile::index_t min_seqlen_q; @@ -352,6 +353,7 @@ struct fmha_fwd_pagedkv_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; ck_tile::index_t mask_type; ck_tile::index_t min_seqlen_q; }; @@ -442,6 +444,7 @@ struct fmha_fwd_splitkv_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; ck_tile::index_t mask_type; }; @@ -561,6 +564,7 @@ struct fmha_batch_prefill_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; ck_tile::index_t mask_type; float p_drop; @@ -613,6 +617,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_o, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type, args.min_seqlen_q, args.p_drop, @@ -663,6 +668,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_o, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type, args.p_drop, args.s_randval, @@ -824,6 +830,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.batch_stride_v, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type, args.min_seqlen_q); } @@ -869,6 +876,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.batch_stride_o, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type); } }(); @@ -935,6 +943,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.split_stride_o_acc, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type); } else @@ -982,6 +991,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.split_stride_o_acc, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type); } }(); @@ -1142,6 +1152,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.batch_stride_v, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type, args.p_drop, args.s_randval, @@ -1194,6 +1205,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.batch_stride_o, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type, args.p_drop, args.s_randval, @@ -1228,7 +1240,8 @@ template + bool kSkipMinSeqlenQ_ = false, + bool kHasSink_ = false> struct fmha_fwd_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -1254,6 +1267,7 @@ struct fmha_fwd_traits_ static constexpr bool kPadDv = kPadDv_; static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; + static constexpr bool kHasSink = kHasSink_; }; template @@ -1280,7 +1294,8 @@ template + bool kSkipMinSeqlenQ_ = false, + bool kHasSink_ = false> struct fmha_fwd_pagedkv_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -1305,6 +1320,7 @@ struct fmha_fwd_pagedkv_traits_ static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; + static constexpr bool kHasSink = kHasSink_; }; template @@ -1327,6 +1343,7 @@ template @@ -1440,6 +1458,7 @@ struct fmha_fwd_traits bool has_dropout; quant_scale_enum qscale_type; bool skip_min_seqlen_q = false; + bool has_sink = false; // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); @@ -1458,6 +1477,7 @@ struct fmha_fwd_pagedkv_traits bool use_pagedkv = true; bool do_fp8_static_quant = false; bool skip_min_seqlen_q = false; + bool has_sink = false; // TODO: padding check is inside this api }; @@ -1477,6 +1497,7 @@ struct fmha_fwd_splitkv_traits bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; bool do_fp8_static_quant; + bool has_sink = false; // TODO: padding check is inside this api }; float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index bca4c60bc6..536fcb0692 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -879,6 +879,7 @@ fwd_result fmha_fwd_run(mode_enum mode, traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; + traits.has_sink = mask.sink > 0 ? true : false; traits.has_lse = lse; if constexpr(std::is_same_v>) @@ -1042,6 +1043,7 @@ fwd_result fmha_fwd_run(mode_enum mode, args.window_size_left = mask.left; args.window_size_right = mask.right; + args.sink_size = mask.sink; args.mask_type = static_cast(mask.type); if constexpr(std::is_same_v>) @@ -1645,7 +1647,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::reference_batched_masking( s_host_ref, ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + mask.left, mask.right, mask.sink, real_seqlen_q, real_seqlen_k)); } else { @@ -1657,6 +1659,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::make_generic_attention_mask_from_lr_window( mask.left, mask.right, + mask.sink, real_seqlen_q, real_seqlen_k, mask.type == mask_enum::mask_top_left)); @@ -1666,6 +1669,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::make_generic_attention_mask_from_lr_window( mask.left, mask.right, + mask.sink, real_seqlen_q, real_seqlen_k, mask.type == mask_enum::mask_top_left)); diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index 32157a2245..f85b811116 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -25,6 +25,7 @@ struct mask_info ck_tile::index_t seqlen_k; ck_tile::index_t y, x; ck_tile::index_t left, right; // FA style SWA left/right + ck_tile::index_t sink; void serialize(std::ostream& os) const { @@ -58,13 +59,14 @@ struct mask_info ck_tile::index_t window_size = std::stoi(v); ck_tile::index_t left_size = -1; ck_tile::index_t right_size = 0; + ck_tile::index_t sink_size = 0; if(window_size > 0) { left_size = window_size / 2; right_size = window_size - 1 - left_size; } auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - left_size, right_size, y_total, x_total, t == "xt"); + left_size, right_size, sink_size, y_total, x_total, t == "xt"); tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; tmp.y = r.at(ck_tile::number<0>{}); @@ -79,27 +81,54 @@ struct mask_info { throw std::invalid_argument("invalid mask value: " + str); } - ck_tile::index_t v0 = std::stoi(v.substr(0, found_1)); - ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1)); + tmp.type = mask_enum::window_generic; + ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); + auto found_2 = v.find(',', found_1 + 1); + ck_tile::index_t v1 = 0; + ck_tile::index_t sink = 0; + // ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation if(t == "t") { + if(found_2 != std::string::npos) + { + v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str()); + sink = atoi(v.substr(found_2 + 1).c_str()); + } + else + { + v1 = atoi(v.substr(found_1 + 1).c_str()); + sink = 0; + } tmp.type = mask_enum::mask_top_left; auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, true); + v0, v1, sink, y_total, x_total, true); tmp.y = r.at(ck_tile::number<0>{}); tmp.x = r.at(ck_tile::number<1>{}); tmp.left = v0; tmp.right = v1; + tmp.sink = sink; } else if(t == "b") { + if(found_2 != std::string::npos) + { + v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str()); + sink = atoi(v.substr(found_2 + 1).c_str()); + } + else + { + v1 = atoi(v.substr(found_1 + 1).c_str()); + sink = 0; + } tmp.type = mask_enum::mask_bottom_right; auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, false); + v0, v1, sink, y_total, x_total, false); tmp.y = r.at(ck_tile::number<0>{}); tmp.x = r.at(ck_tile::number<1>{}); tmp.left = v0; tmp.right = v1; + tmp.sink = sink; } else if(t == "g") { @@ -108,6 +137,7 @@ struct mask_info tmp.x = v1; tmp.left = v0; // TODO: don't use this? tmp.right = v1; + tmp.sink = 0; } } else @@ -126,6 +156,7 @@ struct mask_info tmp.x = 1; tmp.left = -1; tmp.right = 0; + tmp.sink = 0; } else if(str == "2" || str == "b") { @@ -134,6 +165,7 @@ struct mask_info tmp.x = seqlen_k - seqlen_q + 1; tmp.left = -1; tmp.right = 0; + tmp.sink = 0; } else { diff --git a/example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh new file mode 100644 index 0000000000..de3bff25ed --- /dev/null +++ b/example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' +# mode=0 +# export HIP_VISIBLE_DEVICES=4 + +TEST_SPLITKV=0 +TEST_APPENDKV=0 +# options: +# -s: run splitkv tests +# -a: run appendkv tests +while getopts ":sa" opt; do + case "${opt}" in + s) + TEST_SPLITKV=1 + ;; + a) + TEST_APPENDKV=1 + ;; + *) + ;; + esac +done + +run_fp16_bf16_tests() { + local NUM_SPLITS="1" + local PAGE_BLOCK_SIZE="0" + local CACHE_BATCH_IDX="0" + + if [ $TEST_SPLITKV -eq 1 ] ; then + NUM_SPLITS="$NUM_SPLITS 2 3" + PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128" + CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1" + fi + + for prec in "fp16"; do + for mode in 1 0 ; do + for perm in 0 1 ; do + for vlayout in "r" "c" ; do + for batch in 1 4; do + for head in 1; do + for h_k in 1; do + for q_seq in 128 512 ; do + for kv_seq in 128 1024; do + for hdim in 32 64 128 256; do #256 + for lse in 0 1 ; do + for bias in "e" ; do + for p_drop in 0.0 0.2; do # 0.0 + for mask in "t:2,0,4" "b:1,0,2"; do + for num_splits in $NUM_SPLITS ; do + for page_block_size in $PAGE_BLOCK_SIZE ; do + for cache_batch_idx in $CACHE_BATCH_IDX ; do + + # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=$batch -h=$head -h_k=$h_k -d=16 -d_v=$hdim -s=$q_seq -s_k=$kv_seq -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS -mask=$mask + + done ; done ; done ; done ; done + done ; done ; done ; done ; done + done ; done ; done ; done ; done + done ; done +} + + +set -x + +run_fp16_bf16_tests + +set +x diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh index 4fbde37cae..456c3986fa 100755 --- a/example/ck_tile/01_fmha/script/run_full_test.sh +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -39,6 +39,7 @@ function print_log_header(){ #run verification tests time example/ck_tile/01_fmha/script/smoke_test_fwd.sh time example/ck_tile/01_fmha/script/smoke_test_bwd.sh +time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh #run performance benchmarks export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh new file mode 100755 index 0000000000..664c825418 --- /dev/null +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh @@ -0,0 +1,86 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# TODO: run this script from CK root or build directory +#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd" +set -euo pipefail + +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_fwd +EXE="$(find . -name $EXE_NAME -type f | head -n 1)" +KNAME=1 +GPU_arch=$GPU_arch +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi +set -x + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' + + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2 + +# window_size[2,0], sink_size = 2 + +# x=1/y=3 +# 1 * * * * * * * 1 * * * * * * * +# 1 1 * * * * * * 1 1 * * * * * * +# 1 1 1 * * * * * ----> 1 1 1 * * * * * +# * 1 1 1 * * * * 1 1 1 1 * * * * +# * * 1 1 1 * * * 1 1 1 1 1 * * * +# * * * 1 1 1 * * 1 1 * 1 1 1 * * +# * * * * 1 1 1 * 1 1 * * 1 1 1 * +# * * * * * 1 1 1 1 1 * * * 1 1 1 +# l=2/r=0(tl) l=2/r=0/s=2(tl) + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2 + +# x=4/y=1 +# 1 1 1 1 * * * * 1 1 1 1 * * * * +# * 1 1 1 1 * * * 1 1 1 1 1 * * * +# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * * +# * * * 1 1 1 1 * 1 1 * 1 1 1 1 * +# * * * * 1 1 1 1 1 1 * * 1 1 1 1 +# l=0/r=3(tl) l=0/r=3/s=2(tl) +# l=3/r=0(br) l=3/r=0/s=2(br) + + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2 + +# x=4/y=-1 +# * * 1 1 * * * * 1 1 1 1 * * * * +# * * * 1 1 * * * 1 1 * 1 1 * * * +# * * * * 1 1 * * ----> 1 1 * * 1 1 * * +# * * * * * 1 1 * 1 1 * * * 1 1 * +# * * * * * * 1 1 1 1 * * * * 1 1 +# l=1/r=0(br) l=1/r=0/s=2(br) + + +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2 + +# x=-1/y=5 + +# * * * * * * * * * * * * +# * * * * * * * * * * * * +# 1 * * * * * 1 * * * * * +# 1 1 * * * * 1 1 * * * * +# 1 1 1 * * * ----> 1 1 1 * * * +# * 1 1 1 * * 1 1 1 1 * * +# * * 1 1 1 * 1 1 1 1 1 * +# * * * 1 1 1 1 1 * 1 1 1 +# l=2/r=0(br) l=2/r=0/s=2(br) + + +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2 +# x=-1/y=8 +# * * * * * * * * * * +# * * * * * * * * * * +# 1 * * * * ----> 1 * * * * +# 1 1 * * * 1 1 * * * +# 1 1 1 * * 1 1 1 * * +# 1 1 1 1 * 1 1 1 1 * +# 1 1 1 1 1 1 1 1 1 1 +# 1 1 1 1 1 1 1 1 1 1 +# l=2/r=0(br) l=2/r=0/s=2(br) + diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index b25aec101b..47c47334e7 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -459,7 +459,7 @@ struct PipelineTypeTraits ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; }; -auto create_args() +inline auto create_args() { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "3840", "m dimension") diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index c4f100b36b..78f3a9b0b3 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -197,8 +197,8 @@ bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, return pass; } -std::tuple -parse_gemm_size(ck_tile::ArgParser& arg_parser) +std::tuple inline parse_gemm_size( + ck_tile::ArgParser& arg_parser) { ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index aabbfff3bd..7a4760e1da 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -91,7 +91,7 @@ struct GemmConfigBase { static constexpr bool kPadM = false; static constexpr bool kPadN = false; - static constexpr bool kPadK = false; + static constexpr bool kPadK = true; static constexpr bool PermuteA = false; static constexpr bool PermuteB = false; diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index fa5e1f12e3..a0e875448d 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -391,25 +391,18 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); - if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if(K % QuantGroupSize::kK != 0) - { - throw std::runtime_error( - "K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode"); - } - } ck_tile::index_t AQK, BQK, BQN = 0; if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { - AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize - BQK = 0; // No B quantization + AQK = ck_tile::integer_divide_ceil( + K, QuantGroupSize::kK); // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { - AQK = 0; // No A quantization - BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize + AQK = 0; // No A quantization + BQK = ck_tile::integer_divide_ceil( + K, QuantGroupSize::kK); // Group quantization: BQK = K / GroupSize BQN = ck_tile::integer_divide_ceil(N, QuantGroupSize::kN); } else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant || diff --git a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp index dad31ec637..34c6c6b0ae 100644 --- a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp +++ b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp @@ -7,46 +7,46 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -struct GemmConfigBase +struct GemmConfigurationBase { - static constexpr bool kPadM = true; - static constexpr bool kPadN = true; - static constexpr bool kPadK = true; + static constexpr bool PAD_M = true; + static constexpr bool PAD_N = true; + static constexpr bool PAD_K = true; - static constexpr bool PermuteA = false; - static constexpr bool PermuteB = false; + static constexpr bool PERMUTE_A = false; + static constexpr bool PERMUTE_B = false; - static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = false; + static constexpr bool TRANSPOSE_C = false; + static constexpr bool USE_STRUCTURED_SPARSITY = false; - static constexpr int kBlockPerCu = 1; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool Preshuffle = false; - static constexpr bool DoubleSmemBuffer = false; + static constexpr int BLOCK_PER_CU = 1; + static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t NUM_WAVE_GROUPS = 1; + static constexpr bool PRESHUFFLE = false; + static constexpr bool DOUBLE_SMEM_BUFFER = false; }; -template -struct GemmConfigMemoryInterwave : public GemmConfigBase +template +struct GemmConfigurationMemoryInterwave : public GemmConfigurationBase { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 16; + static constexpr ck_tile::index_t M_TILE = 256; + static constexpr ck_tile::index_t N_TILE = 256; + static constexpr ck_tile::index_t K_TILE = 16; - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; + static constexpr ck_tile::index_t M_WARP = 2; + static constexpr ck_tile::index_t N_WARP = 2; + static constexpr ck_tile::index_t K_WARP = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + static constexpr ck_tile::index_t M_WARP_TILE = 32; + static constexpr ck_tile::index_t N_WARP_TILE = 32; + static constexpr ck_tile::index_t K_WARP_TILE = sizeof(PrecisionType) == 2 ? 8 : 16; - static constexpr bool Persistent = Persistent_; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool PERSISTENT = IsPersistent; + static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave; }; template -struct StreamKGemmTypeConfig +struct StreamKGemmTypeConfiguration { using ADataType = ADataType_; using BDataType = BDataType_; @@ -54,7 +54,7 @@ struct StreamKGemmTypeConfig using CDataType = CDataType_; }; -auto create_args(int argc, char* argv[]) +auto createArgs(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "512", "m dimension") diff --git a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc index d18ac2e68a..7442bd33f2 100644 --- a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc +++ b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc @@ -12,31 +12,35 @@ static constexpr inline auto is_row_major(Layout) } template -auto calculate_rtol_atol(const ck_tile::index_t K, - const ck_tile::index_t kbatch, - const float max_accumulated_value) +auto calculateRtolAtol(const ck_tile::index_t k_dim, + const ck_tile::index_t k_batch, + const float max_accumulated_value) { using ComputeType = std::conditional_t; // Calculate thresholds - const auto rtol = ck_tile::get_relative_threshold( - ck_tile::integer_divide_ceil(K, kbatch)); - const auto atol = ck_tile::get_absolute_threshold( - max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + const auto relative_tolerance = + ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(k_dim, k_batch)); + const auto absolute_tolerance = + ck_tile::get_absolute_threshold( + max_accumulated_value / k_batch, ck_tile::integer_divide_ceil(k_dim, k_batch)); // Calculate error due to multiple WGs working in the same C macro tile - const auto rtol_split_k = - ck_tile::get_relative_threshold(kbatch); - const auto atol_split_k = ck_tile::get_absolute_threshold( - max_accumulated_value, kbatch); + const auto relative_tolerance_split_k = + ck_tile::get_relative_threshold(k_batch); + const auto absolute_tolerance_split_k = + ck_tile::get_absolute_threshold(max_accumulated_value, + k_batch); // Use higher threshold - return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); + return ck_tile::make_tuple(std::max(relative_tolerance, relative_tolerance_split_k), + std::max(absolute_tolerance, absolute_tolerance_split_k)); } -template std::tuple gemm(const ck_tile::StreamKHostArgs& args, - const ck_tile::stream_config& s); + const ck_tile::stream_config& stream_config); -template -std::tuple -invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, - ck_tile::DeviceMem& b_k_n_dev_buf, - ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t stride_A, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C, - int n_warmup, - int n_repeat, - bool flush_cache, - ck_tile::StreamKReductionStrategy reduction_strategy) +std::tuple invokeGemm(ck_tile::DeviceMem& a_m_k_device_memory, + ck_tile::DeviceMem& b_k_n_device_memory, + ck_tile::DeviceMem& c_m_n_device_memory, + ck_tile::index_t m_dim, + ck_tile::index_t n_dim, + ck_tile::index_t k_dim, + ck_tile::index_t stride_a, + ck_tile::index_t stride_b, + ck_tile::index_t stride_c, + int warmup_iterations, + int repeat_iterations, + bool flush_cache, + ck_tile::StreamKReductionStrategy reduction_strategy) { - ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - c_m_n_dev_buf.GetDeviceBuffer(), - M, - N, - K, - stride_A, - stride_B, - stride_C}; + ck_tile::StreamKHostArgs args{a_m_k_device_memory.GetDeviceBuffer(), + b_k_n_device_memory.GetDeviceBuffer(), + c_m_n_device_memory.GetDeviceBuffer(), + m_dim, + n_dim, + k_dim, + stride_a, + stride_b, + stride_c}; - std::tuple ave_time_and_batch; + std::tuple average_time_and_batch; if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) { - ave_time_and_batch = gemm( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache}); + average_time_and_batch = gemm( + args, + ck_tile::stream_config{ + nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache}); } else /*Reduction*/ { - ave_time_and_batch = gemm( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache}); + average_time_and_batch = gemm( + args, + ck_tile::stream_config{ + nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache}); } - return ave_time_and_batch; + return average_time_and_batch; } template -bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, - const ck_tile::HostTensor& c_m_n_ref, - const ck_tile::tuple& rtol_atol, - const char* variant) +bool doVerify(const ck_tile::HostTensor& c_m_n_device_result, + const ck_tile::HostTensor& c_m_n_reference, + const ck_tile::tuple& relative_absolute_tolerances, + const char* variant) { - bool pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_ref, + bool pass = ck_tile::check_err(c_m_n_device_result, + c_m_n_reference, "Error: Incorrect results!", - rtol_atol.at(ck_tile::number<0>{}), - rtol_atol.at(ck_tile::number<1>{})); + relative_absolute_tolerances.at(ck_tile::number<0>{}), + relative_absolute_tolerances.at(ck_tile::number<1>{})); - std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) - << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "Relative error threshold: " + << relative_absolute_tolerances.at(ck_tile::number<0>{}) + << " Absolute error threshold: " + << relative_absolute_tolerances.at(ck_tile::number<1>{}) << std::endl; std::cout << "The " << variant << " verification result is:" << (pass ? "correct" : "fail") << std::endl; return pass; } -ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string& strategy) +ck_tile::StreamKReductionStrategy getReductionStrategyValue(const std::string& strategy) { if(strategy == "atomic") { @@ -156,172 +165,169 @@ ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string } } -template -int run_gemm_example_with_layouts(int argc, - char* argv[], - const ALayout a_layout = ALayout{}, - const BLayout b_layout = BLayout{}, - [[maybe_unused]] const CLayout c_layout = CLayout{}) +int runGemmExampleWithLayouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) { - auto [result, arg_parser] = create_args(argc, argv); + auto [result, arg_parser] = createArgs(argc, argv); if(!result) return -1; - static_assert(!GemmConfig::Preshuffle, "Not implemented"); - static_assert(!GemmConfig::UseStructuredSparsity, "Not implemented"); - static_assert(!GemmConfig::PermuteA, "Not implemented"); - static_assert(!GemmConfig::PermuteB, "Not implemented"); + static_assert(!GemmConfiguration::PRESHUFFLE, "Not implemented"); + static_assert(!GemmConfiguration::USE_STRUCTURED_SPARSITY, "Not implemented"); + static_assert(!GemmConfiguration::PERMUTE_A, "Not implemented"); + static_assert(!GemmConfiguration::PERMUTE_B, "Not implemented"); - using ADataType = typename TypeConfig::ADataType; - using BDataType = typename TypeConfig::BDataType; - using AccDataType = typename TypeConfig::AccDataType; - using CDataType = typename TypeConfig::CDataType; + using ADataType = typename TypeConfiguration::ADataType; + using BDataType = typename TypeConfiguration::BDataType; + using AccumulatorDataType = typename TypeConfiguration::AccDataType; + using CDataType = typename TypeConfiguration::CDataType; - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); + ck_tile::index_t m_dim = arg_parser.get_int("m"); + ck_tile::index_t n_dim = arg_parser.get_int("n"); + ck_tile::index_t k_dim = arg_parser.get_int("k"); - ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t stride_a = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_b = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_c = arg_parser.get_int("stride_c"); + int warmup_iterations = arg_parser.get_int("warmup"); + int repeat_iterations = arg_parser.get_int("repeat"); ck_tile::index_t init_method = arg_parser.get_int("init"); bool flush_cache = arg_parser.get_bool("flush_cache"); - ck_tile::StreamKReductionStrategy reduction_strategy = - get_reduction_strategy_value(arg_parser.get_str("reduction_strategy")); + getReductionStrategyValue(arg_parser.get_str("reduction_strategy")); - stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); - stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + stride_a = ck_tile::get_default_stride(m_dim, k_dim, stride_a, is_row_major(a_layout)); + stride_b = ck_tile::get_default_stride(k_dim, n_dim, stride_b, is_row_major(b_layout)); + stride_c = ck_tile::get_default_stride(m_dim, n_dim, stride_c, is_row_major(CLayout{})); - ck_tile::HostTensor a_m_k( - ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); - ck_tile::HostTensor c_m_n_dev_result( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + ck_tile::HostTensor a_m_k_host( + ck_tile::host_tensor_descriptor(m_dim, k_dim, stride_a, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n_host( + ck_tile::host_tensor_descriptor(k_dim, n_dim, stride_b, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_device_result( + ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{}))); if(init_method == 0) { - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n_host); } else if(init_method == 1) { - ck_tile::FillMonotonicSeq{}(a_m_k); - ck_tile::FillMonotonicSeq{}(b_k_n); + ck_tile::FillMonotonicSeq{}(a_m_k_host); + ck_tile::FillMonotonicSeq{}(b_k_n_host); } else if(init_method == 2) { - ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k); - ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n); + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n_host); } else { - a_m_k.SetZero(); - b_k_n.SetZero(); + a_m_k_host.SetZero(); + b_k_n_host.SetZero(); } - ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + ck_tile::DeviceMem a_m_k_device_memory(a_m_k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_device_memory(b_k_n_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_device_memory(c_m_n_device_result.get_element_space_size_in_bytes()); - a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_k_n.data()); - c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); + a_m_k_device_memory.ToDevice(a_m_k_host.data()); + b_k_n_device_memory.ToDevice(b_k_n_host.data()); + c_m_n_device_memory.SetZero(); + c_m_n_device_result.SetZero(); + auto [average_time, num_wgs_per_tile] = invokeGemm, + AccumulatorDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_device_memory, + b_k_n_device_memory, + c_m_n_device_memory, + m_dim, + n_dim, + k_dim, + stride_a, + stride_b, + stride_c, + warmup_iterations, + repeat_iterations, + flush_cache, + reduction_strategy); - auto [ave_time, num_wgs_per_tile] = invoke_gemm, - AccDataType, - CDataType, - ALayout, - BLayout, - ck_tile::tuple<>, - CLayout>(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - n_warmup, - n_repeat, - flush_cache, - reduction_strategy); + c_m_n_device_memory.FromDevice(c_m_n_device_result.data()); - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_byte = - sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K - << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C + std::size_t flop = std::size_t(2) * m_dim * n_dim * k_dim; + std::size_t num_byte = sizeof(ADataType) * m_dim * k_dim + sizeof(BDataType) * n_dim * k_dim + + sizeof(CDataType) * m_dim * n_dim; + float tflops = static_cast(flop) / 1.E9 / average_time; + float gb_per_sec = num_byte / 1.E6 / average_time; + std::cout << "Run Gemm kernel with M=" << m_dim << " N=" << n_dim << " K=" << k_dim + << " StrideA=" << stride_a << " StrideB=" << stride_b << " StrideC=" << stride_c << " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name << " C_Layout=" << CLayout::name << " A_Type=" << ck_tile::DataTypeTraits::name << " B_Type=" << ck_tile::DataTypeTraits::name << " C_Type=" << ck_tile::DataTypeTraits::name << " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " " - << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time + << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << average_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; - bool pass = false; // Memory on host to store gpu reference result - ck_tile::HostTensor c_m_n_ref( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - c_m_n_ref.SetZero(); + ck_tile::HostTensor c_m_n_reference( + ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{}))); + c_m_n_reference.SetZero(); if(arg_parser.get_int("v") == 1) // Validate on the CPU { - ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_ref); + ck_tile::reference_gemm( + a_m_k_host, b_k_n_host, c_m_n_reference); const float max_accumulated_value = - *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, num_wgs_per_tile, max_accumulated_value); - pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU"); + *std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end()); + const auto relative_absolute_tolerances = + calculateRtolAtol( + k_dim, num_wgs_per_tile, max_accumulated_value); + pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "CPU"); } else if(arg_parser.get_int("v") == 2) // Validate on the GPU { // Memory on device to store gpu reference result - ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes()); - c_m_n_gpu_buf_ref.SetZero(); - - ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); - BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); - CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + ck_tile::DeviceMem c_m_n_gpu_buffer_reference( + c_m_n_reference.get_element_space_size_in_bytes()); + c_m_n_gpu_buffer_reference.SetZero(); + ADataType* d_A = static_cast(a_m_k_device_memory.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_device_memory.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buffer_reference.GetDeviceBuffer()); ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); - - c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data()); + CLayout>( + d_A, d_B, d_C, m_dim, n_dim, k_dim, stride_a, stride_b, stride_c); + c_m_n_gpu_buffer_reference.FromDevice(c_m_n_reference.data()); const float max_accumulated_value = - *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, num_wgs_per_tile, max_accumulated_value); - pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU"); + *std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end()); + const auto relative_absolute_tolerances = + calculateRtolAtol( + k_dim, num_wgs_per_tile, max_accumulated_value); + pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "GPU"); } return pass; diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp index 83795fbf6a..d3ee9fe9c6 100644 --- a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp +++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp @@ -4,11 +4,11 @@ #include "gemm_utils.hpp" #include "ck_tile/ops/common.hpp" -template std::tuple gemm(const ck_tile::StreamKHostArgs& args, - const ck_tile::stream_config& s) + const ck_tile::stream_config& stream_config) { - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence, - GemmConfig::PermuteA, - GemmConfig::PermuteB>; + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence, + GemmConfiguration::PERMUTE_A, + GemmConfiguration::PERMUTE_B>; - using TilePartitioner = - ck_tile::StreamKTilePartitioner; + using TilePartitioner = ck_tile:: + StreamKTilePartitioner; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; - const auto Run = [&](const auto memory_operation) -> std::tuple { + const auto runKernel = [&](const auto memory_operation) -> std::tuple { // We create the GEMM pipeline without specifying has_hot_loop or tail_num. // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -61,39 +67,39 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, ck_tile::CShuffleEpilogueProblem>; + GemmConfiguration::NUM_WAVE_GROUPS>>; using Kernel = ck_tile::StreamKKernel; - auto kargs = Kernel::MakeKernelArgs(args); - const auto workspace_size = Kernel::GetWorkSpaceSize(kargs); + auto kernel_args = Kernel::MakeKernelArgs(args); + const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args); ck_tile::DeviceMem workspace_data(workspace_size); workspace_data.SetZero(); - kargs.workspace_ptr = workspace_data.GetDeviceBuffer(); + kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer(); - dim3 grids = Kernel::GridSize(kargs.tile_partitioner); + dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner); dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) + if(!Kernel::IsSupportedArgument(kernel_args)) { throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - if(s.log_level_ > 0) + if(stream_config.log_level_ > 0) { std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' << "shape: " << GemmShape::GetName() << '\n' @@ -109,7 +115,7 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, { // Clear the output C tensor results after each repetition of the kernel hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_)); } else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) { @@ -120,45 +126,47 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, std::function preprocess = reset_data_buffers; - float ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + float average_time = + ck_tile::launch_kernel_time_mask(stream_config, + preprocess, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kernel_args)); - ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile(); - return std::tuple{ave_time, num_wgs_per_tile}; + ck_tile::index_t num_wgs_per_tile = + kernel_args.tile_partitioner.estimate_num_wgs_per_tile(); + return std::tuple{average_time, num_wgs_per_tile}; }; if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy) { - return Run(ck_tile::integral_constant{}); + return runKernel(ck_tile::integral_constant{}); } else // We are using ck_tile::StreamKReductionStrategy::Reduction { - return Run(ck_tile::integral_constant{}); + return runKernel(ck_tile::integral_constant{}); } } #include "run_gemm_example.inc" -template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +template +int runGemmExamplePrecisionType(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return runGemmExampleWithLayouts( argc, argv, Row{}, Col{}, Row{}); } else @@ -169,72 +177,74 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a return 0; } -template