diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 77b63a0c83..47cf6b3ad4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation import copy @@ -8,21 +8,13 @@ import fnmatch import itertools from pathlib import Path from typing import List, Optional, Tuple, Dict, Literal +from collections import defaultdict from codegen.cmake_config import * from codegen.cpp_symbol_map import * +from codegen.utils import update_file -BWD_DQDKDV_PIPELINE_MAP = { - "kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP", - "kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR", -} - -BWD_DQDKDV_PIPELINE_ENUM_MAP = { - "kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP", - "kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR", -} - FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n // auto generated by generate.py @@ -56,8 +48,8 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; -using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, +using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits; -using fmha_bwd_pipeline_{F_idx} = {F_pipeline}; +using fmha_bwd_pipeline_{F_idx} = ck_tile::BlockFmhaBwdDQDKDVPipeline; using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, - {F_skpad}, + false, {F_dpad}>>; using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, - {F_skpad}, + false, {F_dvpad}>>; using fmha_bwd_dq_dk_dv_kernel_{F_idx} = @@ -115,13 +107,10 @@ using fmha_bwd_dq_dk_dv_kernel_{F_idx} = using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, - {F_pipeline_enum}, fmha_mask_{F_idx}, fmha_dropout_{F_idx}, {F_bias}, {F_dbias}, - {F_spad}, - {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>; @@ -195,15 +184,18 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < """ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; - using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>; + ({F_scheck}) && 
({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>; r = fmha_bwd_(s, a); return r; }} """ +# M0 size for 1d kernels (dot/convert) +M0_1D = 64 + # GEMM0: Q@K=S^T # GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) # GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) @@ -249,8 +241,6 @@ class FmhaBwdDQDKDVKernel: F_hdim : int # hdim F_dtype : str # data type F_tile : FmhaBwdDQDKDVTileSize - F_spad : str # true/false - F_skpad : str # F_dpad : str # F_dvpad : str # F_bias : str # @@ -259,7 +249,6 @@ class FmhaBwdDQDKDVKernel: F_mask : str # value from MASK_MAP F_mode : str # value from MODE_MAP F_deterministic : str # - F_pipeline : str # mask_impl : str # @property @@ -293,8 +282,6 @@ class FmhaBwdDQDKDVKernel: F_wm1 = self.F_tile.F_wm1, F_wn1 = self.F_tile.F_wn1, F_wk1 = self.F_tile.F_wk1, - F_spad = BOOL_MAP[self.F_spad], - F_skpad = BOOL_MAP[self.F_skpad], F_dpad = BOOL_MAP[self.F_dpad], F_dvpad = BOOL_MAP[self.F_dvpad], F_bias = BIAS_MAP[self.F_bias], @@ -304,21 +291,18 @@ class FmhaBwdDQDKDVKernel: F_mask = get_mask_map(self.mask_impl)[self.F_mask], F_mode = MODE_MAP[self.F_mode], F_deterministic = BOOL_MAP[self.F_deterministic], - F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], - F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) + ) @property def name(self) -> str: def pad_name() -> str: n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' if self.F_dpad == 't' : n += 'd' if self.F_dvpad == 't' : n += 'dv' if n != '' : n = 'p' + n 
return n pn = pad_name() - n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}' + n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name if pn != '' : n += f'_{pn}' else: n += '_npad' @@ -347,20 +331,15 @@ class FmhaBwdDQDKDVKernel: return self.name + ".cpp" # TODO: design a more practical way to do it -# this is current supported tile size & pipeline. +# this is current supported tile size. def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"], - '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"], - '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"], - # '160' : [FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - # "kr_ktr_vr_iglp", "kr_ktr_vr"], - '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"] + '32' : FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + '64' : FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + '128' : FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + # '160' : FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + '256' : FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 
1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), } else: return None @@ -375,7 +354,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot typename FmhaBwdTypeConfig::ODataType, typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::DDataType, - /* BlockSize = */ 64, + /* BlockSize = M0 = */ 64, {F_hdim}, {F_mode}, fmha_bwd_dot_do_o_trait_{F_idx}>; @@ -580,7 +559,6 @@ class FmhaBwdConvertQGradKernel: @dataclass(frozen=True) class FmhaBwdApiTrait: idx : int # this is not a tunable, but a counter to differentiate symbol - pipeline : str # sync with fmha_bwd_traits<>, to generate fallback calls hdim : int dtype : str # data type @@ -590,9 +568,7 @@ class FmhaBwdApiTrait: bias : str dbias : str dropout : str - spad : str - spad1 : str # spad for dot/convert kernel - skpad : str + spad1d : str # spad for 1d kernels (dot/convert) dpad : str dvpad : str deterministic : str @@ -611,24 +587,14 @@ class FmhaBwdApiTrait: def bhdv(self) -> int: return self.tile.F_bhdv - def scheck(self, spad1 : str) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.spad == 't' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} != 0' - elif self.spad == 'f' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0' - else: # self.skpad == 'f' and skpad1 == 'f' - return 'a.seqlen_q % 64 == 0' - @property - def skcheck(self) -> str: + def scheck(self) -> str: if self.mode == 'group': return 'true' # always support - elif self.skpad == 't': - return f'a.seqlen_k % {self.bn0} != 0' - else: - return f'a.seqlen_k % {self.bn0} == 0' + elif self.spad1d == 't': + return f'a.seqlen_q % {M0_1D} != 0' + else: # self.spad1d == 'f' + return f'a.seqlen_q % {M0_1D} == 0' @property def dcheck(self) -> str: @@ -647,14 +613,14 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 - return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1, + return 
FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d, F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) @property def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile, - F_spad=self.spad, F_skpad=self.skpad, F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, - F_dbias=self.dbias, F_dropout=self.dropout, F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, F_pipeline=self.pipeline, mask_impl=self.mask_impl) + F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout, + F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl) @property def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel: @@ -664,48 +630,46 @@ class FmhaBwdApiTrait: return 2 return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, - F_bm0=64, F_bn0=self.tile.F_bn0, F_spad=self.spad, F_dpad=self.dpad, + F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), F_deterministic=self.deterministic) class FmhaBwdApiPool: def __init__(self, mask_impl): - self.dq_dk_dv_pool = dict() + self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(list)) self.mask_impl = mask_impl def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None: # TODO: do we need to check duplication? 
- if trait.dtype not in self.dq_dk_dv_pool.keys(): - self.dq_dk_dv_pool[trait.dtype] = dict() - if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): - self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() - self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + @staticmethod + def if_(i: int) -> str: + return 'if' if i == 0 else 'else if' + + def _api_innders(self, traits: List[FmhaBwdApiTrait]) -> str: + inners = "" + i = 0 + for trait in traits: + inners += FMHA_BWD_API_INNER_DISPATCH.format(F_if=self.if_(i), F_mode=MODE_MAP[trait.mode], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], + F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], + F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_deterministic=BOOL_MAP[trait.deterministic]) + i += 1 + return inners + @property def api(self) -> str: per_dtypes=str() - for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): + for i, dtype in enumerate(self.dq_dk_dv_pool): per_hdim_case=str() - for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): + for j, hdim in enumerate(self.dq_dk_dv_pool[dtype]): traits=self.dq_dk_dv_pool[dtype][hdim] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - for spad1 in ["t", "f"]: - if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")): - continue - inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], 
F_bias_check=BIAS_CHECK_MAP[trait.bias], - F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], - F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], - F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_deterministic=BOOL_MAP[trait.deterministic]) - - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + inners = self._api_innders(traits) + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(j), F_hdim=hdim, F_inner_dispatch=inners) + per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(i), F_dtype=dtype, F_hdim_case=per_hdim_case) if not per_dtypes: # empty string we add some ignore to suppress warning in api per_dtypes += ' (void)t ; (void)s ; (void)a;' @@ -730,21 +694,16 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) if d is None: continue - for hdim_str, mode, mask, bias, dbias, dropout, spad, spad1, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 6)): - tile = d[hdim_str][0] - ppl = d[hdim_str][1] + for hdim_str, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)): + tile = d[hdim_str] hdim = int(hdim_str) 
- if (mode == "group") and (spad == "f" or skpad == "f"): - continue - if (spad1 == "f") and (spad == "t" or mode == "group"): + if (mode == "group") and (spad1d == "f"): continue if ((bias == "no" or bias == "alibi") and dbias == "t"): continue if ("wg32" in dropout): continue - if (dpad == "t" or dvpad == "t"): - ppl = d[hdim_str][2] - t = FmhaBwdApiTrait(idx=0, pipeline=ppl, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad=spad, spad1=spad1, skpad=skpad, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl) + t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl) if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): continue @@ -808,13 +767,13 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list) - (output_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + update_file(output_dir / FMHA_BWD_API_FILENAME, api_pool.api) for k in kernels_dot_do_o: - (output_dir / k.filename).write_text(k.template) + update_file(output_dir / k.filename, k.template) for k in kernels_convert_dq: - (output_dir / k.filename).write_text(k.template) + update_file(output_dir / k.filename, k.template) for k in kernels_dq_dk_dv: - (output_dir / k.filename).write_text(k.template) + update_file(output_dir / k.filename, k.template) def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None: diff --git a/example/ck_tile/01_fmha/codegen/utils.py b/example/ck_tile/01_fmha/codegen/utils.py new file mode 100644 index 0000000000..e3bbb18c42 --- /dev/null +++ 
b/example/ck_tile/01_fmha/codegen/utils.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import os.path as path + + +def update_file(file_path, content): + """Update the file at file_path with the given content if it differs from the existing content. + + It avoids unnecessary touching of the file which triggers rebuilds + """ + + existing_content = "" + if path.exists(file_path): + with open(file_path, "r") as file: + existing_content = file.read() + if existing_content == content: + return + with open(file_path, "w") as file: + file.write(content) diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 9179dbd9be..c999cf750e 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -357,31 +357,25 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) template struct fmha_bwd_dq_dk_dv_traits_ { - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; - using FmhaMask = ck_tile::remove_cvref_t; - using FmhaDropout = ck_tile::remove_cvref_t; - static constexpr auto BiasEnum = BiasEnum_; - static constexpr bool kHasBiasGrad = kHasBiasGrad_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadSK = kPadSK_; - static constexpr bool kPadD = kPadD_; - static constexpr bool kPadDv = kPadDv_; - static constexpr bool kIsDeterministic = kIsDeterministic_; + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + using FmhaMask = ck_tile::remove_cvref_t; + using FmhaDropout = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsDeterministic = kIsDeterministic_; }; template diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 188cebaabc..9f3c996873 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -27,6 +27,7 @@ #include "ck_tile/core/container/thread_buffer.hpp" #include "ck_tile/core/container/tuple.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/e8m0.hpp" #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/int8.hpp" @@ -74,6 +75,7 @@ #include "ck_tile/core/utility/literals.hpp" #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/philox_rand.hpp" +#include "ck_tile/core/utility/print.hpp" #include 
"ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/reduce_operator.hpp" #include "ck_tile/core/utility/static_counter.hpp" diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index f7f9489f4c..7511413bba 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/print.hpp" namespace ck_tile { @@ -139,20 +140,19 @@ struct pass_through : public base_transform<1, 1> { return make_tuple(low_vector_lengths, low_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("pass_through{"); - - // - printf("up_lengths_:"); - print(up_lengths_); - - // - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const pass_through& pt) +{ + printf("pass_through{"); + + printf("up_lengths_: "); + print(pt.get_upper_lengths()); + + printf("}"); +} + template ck_tile::is_known_at_compile_time::value && ck_tile::is_known_at_compile_time::value; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("pad{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - // - printf("left_pad_length_: "); - print(left_pad_length_); - printf(", "); - - // - printf("right_pad_length_: "); - print(right_pad_length_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void +print(const pad& p) +{ + printf("pad{"); + printf("up_lengths_: "); + print(p.up_lengths_); + printf(", left_pad_length_: "); + print(p.left_pad_length_); + printf(", right_pad_length_: "); + print(p.right_pad_length_); + printf("}"); +} + template struct left_pad { @@ -330,24 +326,20 @@ struct left_pad // It's up to runtime to check the padding length should be multiple of vector length return 
make_tuple(low_vector_lengths, low_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("left_pad{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - // - printf("left_pad_length_: "); - print(left_pad_length_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void +print(const left_pad& lp) +{ + printf("left_pad{"); + printf("up_lengths_: "); + print(lp.up_lengths_); + printf(", left_pad_length_: "); + print(lp.left_pad_length_); + printf("}"); +} + template struct right_pad : public base_transform<1, 1> { @@ -430,24 +422,20 @@ struct right_pad : public base_transform<1, 1> // It's up to runtime to check the padding length should be multiple of vector length return make_tuple(low_vector_lengths, low_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("right_pad{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - // - printf("right_pad_length_: "); - print(right_pad_length_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void +print(const right_pad& rp) +{ + printf("right_pad{"); + printf("up_lengths_: "); + print(rp.up_lengths_); + printf(", right_pad_length_: "); + print(rp.right_pad_length_); + printf("}"); +} + // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] // UpLengths and Coefficients can be either of the followings: // 1) Tuple of index_t, which is known at run-time, or @@ -532,24 +520,19 @@ struct embed : public base_transform<1, UpLengths::size()> return ck_tile::is_known_at_compile_time::value && ck_tile::is_known_at_compile_time::value; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("embed{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - // - printf("coefficients_: "); - print(coefficients_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const embed& e) +{ + printf("embed{"); + printf("up_lengths_: "); + print(e.up_lengths_); + printf(", 
coefficients_: "); + print(e.coefficients_); + printf("}"); +} + template struct lambda_merge_generate_MagicDivision_calculate_magic_divisor { @@ -699,24 +682,19 @@ struct merge_v2_magic_division : public base_transform return make_tuple(up_vector_lengths, up_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("merge_v2_magic_division{"); - - // - printf("low_lengths_ "); - print(low_lengths_); - printf(", "); - - // - printf("up_lengths_ "); - print(up_lengths_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const merge_v2_magic_division& m) +{ + printf("merge_v2_magic_division{"); + printf("low_lengths_: "); + print(m.low_lengths_); + printf(", up_lengths_: "); + print(m.up_lengths_); + printf("}"); +} + // Implementation of "merge" transformation primitive that uses division and mod. It is supposed to // be used for low_lengths that are known at compile time and are power of 2, otherwise performance // will be very bad @@ -830,29 +808,21 @@ struct merge_v3_division_mod : public base_transform return make_tuple(up_vector_lengths, up_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("Merge_v3_direct_division_mod{"); - - // - printf("low_lengths_ "); - print(low_lengths_); - printf(", "); - - // - printf("low_lengths_scan_ "); - print(low_lengths_scan_); - printf(", "); - - // - printf("up_lengths_ "); - print(up_lengths_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const merge_v3_division_mod& m) +{ + printf("merge_v3_division_mod{"); + printf("low_lengths_: "); + print(m.low_lengths_); + printf(", low_lengths_scan_: "); + print(m.low_lengths_scan_); + printf(", up_lengths_: "); + print(m.up_lengths_); + printf("}"); +} + template struct unmerge : public base_transform<1, UpLengths::size()> { @@ -958,24 +928,19 @@ struct unmerge : public base_transform<1, UpLengths::size()> return make_tuple(up_vector_lengths, up_vector_strides); } - - CK_TILE_HOST_DEVICE void 
print() const - { - printf("unmerge{"); - - // - printf("up_lengths_"); - print(up_lengths_); - printf(", "); - - // - printf("up_lengths_scan_"); - print(up_lengths_scan_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const unmerge& u) +{ + printf("unmerge{"); + printf("up_lengths_: "); + print(u.up_lengths_); + printf(", up_lengths_scan_: "); + print(u.up_lengths_scan_); + printf("}"); +} + template struct freeze : public base_transform<1, 0> { @@ -1023,19 +988,17 @@ struct freeze : public base_transform<1, 0> { return ck_tile::is_known_at_compile_time::value; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("freeze{"); - - // - printf("low_idx_: "); - print(low_idx_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const freeze& f) +{ + printf("freeze{"); + printf("low_idx_: "); + print(f.low_idx_); + printf("}"); +} + // insert a dangling upper dimension without lower dimension template struct insert : public base_transform<0, 1> @@ -1092,18 +1055,17 @@ struct insert : public base_transform<0, 1> { return ck_tile::is_known_at_compile_time::value; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("insert{"); - - // - print(up_lengths_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const insert& i) +{ + printf("insert{"); + printf("up_lengths_: "); + print(i.up_lengths_); + printf("}"); +} + // replicate the original tensor and create a higher dimensional tensor template struct replicate : public base_transform<0, UpLengths::size()> @@ -1152,21 +1114,19 @@ struct replicate : public base_transform<0, UpLengths::size()> return ck_tile::is_known_at_compile_time::value; } - CK_TILE_HOST_DEVICE void print() const - { - printf("replicate{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - - printf("}"); - } - // UpLengths up_lengths_; }; +template +CK_TILE_HOST_DEVICE static void print(const replicate& r) +{ + printf("replicate{"); + printf("up_lengths_: "); + 
print(r.up_lengths_); + printf("}"); +} + template struct slice : public base_transform<1, 1> { @@ -1238,28 +1198,20 @@ struct slice : public base_transform<1, 1> ck_tile::is_known_at_compile_time::value && ck_tile::is_known_at_compile_time::value; } +}; - CK_TILE_HOST_DEVICE void print() const - { - printf("slice{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - // - printf("slice_begin_: "); - print(slice_begin_); - printf(", "); - - // - printf("slice_end_: "); - print(slice_end_); - - printf("}"); - } // namespace ck -}; // namespace ck +template +CK_TILE_HOST_DEVICE static void print(const slice& s) +{ + printf("slice{"); + printf("up_lengths_: "); + print(s.up_lengths_); + printf(", slice_begin_: "); + print(s.slice_begin_); + printf(", slice_end_: "); + print(s.slice_end_); + printf("}"); +} /* * \brief lower_idx = upper_idx % modulus. @@ -1328,19 +1280,19 @@ struct modulo : public base_transform<1, 1> { return ck_tile::is_known_at_compile_time::value; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("Modulus{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const modulo& m) +{ + printf("modulo{"); + printf("modulus_: "); + print(m.modulus_); + printf(", up_lengths_: "); + print(m.up_lengths_); + printf("}"); +} + // 2D XOR, NOTE: "xor" is a keyword template struct xor_t : public base_transform<2, 2> @@ -1424,20 +1376,17 @@ struct xor_t : public base_transform<2, 2> return make_tuple(up_vector_lengths, up_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("xor_t{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const xor_t& x) +{ + printf("xor_t{"); + printf("up_lengths_: "); + print(x.up_lengths_); + printf("}"); +} + template struct offset : public base_transform<1, 1> { @@ -1509,24 +1458,19 @@ struct offset : 
public base_transform<1, 1> return ck_tile::is_known_at_compile_time::value && ck_tile::is_known_at_compile_time::value; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("offset{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - // - printf("offset_length_: "); - print(offset_length_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const offset& o) +{ + printf("offset{"); + printf("up_lengths_: "); + print(o.up_lengths_); + printf(", offset_length_: "); + print(o.offset_length_); + printf("}"); +} + template struct indexing : public base_transform<1, 1> { @@ -1595,20 +1539,19 @@ struct indexing : public base_transform<1, 1> return ck_tile::is_known_at_compile_time::value && IndexingAdaptor::is_known_at_compile_time(); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("embed{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const indexing& i) +{ + printf("indexing{"); + printf("up_lengths_: "); + print(i.up_lengths_); + printf(", iadaptor_: "); + print(i.iadaptor_); + printf("}"); +} + //******************************************************************************************************* template diff --git a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp index 8a3de3e5e0..1f6c389090 100644 --- a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp +++ b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp @@ -77,6 +77,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" +#include "ck_tile/core/utility/print.hpp" namespace ck_tile { @@ -317,4 +318,51 @@ struct TileDistributionEncodingPattern2D +CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D&) +{ + using PatternType = 
TileDistributionEncodingPattern2D; + + printf("TileDistributionEncodingPattern2D: ", + BlockSize, + YPerTile, + XPerTile, + VecSize, + tile_distribution_pattern_to_string(DistributionPattern)); + printf("{: <%d, %d, %d>, : <%d, %d>}\n", + PatternType::Y0, + PatternType::Y1, + PatternType::Y2, + PatternType::X0, + PatternType::X1); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 96df9d70f7..ab42ec8617 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -218,4 +218,19 @@ CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity() #endif } +/// Helper function to convert address space enum to string +CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space) +{ + switch(addr_space) + { + case address_space_enum::generic: return "generic"; + case address_space_enum::global: return "global"; + case address_space_enum::lds: return "lds"; + case address_space_enum::sgpr: return "sgpr"; + case address_space_enum::constant: return "constant"; + case address_space_enum::vgpr: return "vgpr"; + default: return "unknown"; + } +} + } // namespace ck_tile diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp index 94aa40e278..352c645325 100644 --- a/include/ck_tile/core/container/array.hpp +++ b/include/ck_tile/core/container/array.hpp @@ -177,9 +177,27 @@ struct array CK_TILE_HOST_DEVICE constexpr array() {} CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; } CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v; }; - CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); } }; +template +CK_TILE_HOST_DEVICE static void print(const array& a) +{ + printf("array{size: %ld, data: [", static_cast(N)); + for(index_t i = 0; i < N; ++i) + { + if(i > 0) + printf(", "); + print(a[i]); + } + printf("]}"); +} + +template +CK_TILE_HOST_DEVICE static 
void print(const array&) +{ + printf("array{size: 0, data: []}"); +} + template struct vector_traits; diff --git a/include/ck_tile/core/container/map.hpp b/include/ck_tile/core/container/map.hpp index 87b180cafc..7697995c92 100644 --- a/include/ck_tile/core/container/map.hpp +++ b/include/ck_tile/core/container/map.hpp @@ -139,26 +139,21 @@ struct map // WARNING: needed by compiler for C++ range-based for loop only, don't use this function! CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("map{size_: %d, ", size_); - // - printf("impl_: ["); - // - for(const auto& [k, d] : *this) - { - printf("{key: "); - print(k); - printf(", data: "); - print(d); - printf("}, "); - } - // - printf("]"); - // - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const map& m) +{ + printf("map{size_: %d, impl_: [", m.size_); + for(const auto& [k, d] : m) + { + printf("{key: "); + print(k); + printf(", data: "); + print(d); + printf("}, "); + } + printf("]}"); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index 94309dd5dd..905b32dd15 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -9,13 +9,10 @@ #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/type_traits.hpp" -#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/print.hpp" namespace ck_tile { -template -struct static_for; - template struct sequence; @@ -196,15 +193,24 @@ struct sequence { return sequence{}; } - - CK_TILE_HOST_DEVICE static void print() - { - printf("sequence{size: %d, data: [", size()); - ((printf("%d ", Is)), ...); - printf("]}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const sequence&) +{ + printf("sequence<"); + if constexpr(sizeof...(Is) > 0) + { + bool 
first = true; + (([&first](index_t value) { + printf("%s%d", first ? "" : ", ", value); + first = false; + }(Is)), + ...); + } + printf(">"); +} + namespace impl { template struct __integer_sequence; diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 63d145d8b9..4c48b3d477 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -300,12 +300,29 @@ struct tuple : impl::tuple_base, T...> #undef TP_COM_ }; -template +template +CK_TILE_HOST_DEVICE void print(const tuple& t) +{ + printf("tuple<"); + if constexpr(sizeof...(T) > 0) + { + bool first = true; + static_for<0, sizeof...(T), 1>{}([&t, &first](auto i) { + if(!first) + printf(", "); + print(t.get(i)); + first = false; + }); + } + printf(">"); +} + +template struct vector_traits; // specialization for array template -struct vector_traits> +struct vector_traits, void> { using scalar_type = __type_pack_element<0, T...>; static constexpr index_t vector_size = sizeof...(T); diff --git a/include/ck_tile/core/numeric/e8m0.hpp b/include/ck_tile/core/numeric/e8m0.hpp new file mode 100644 index 0000000000..ea94880f27 --- /dev/null +++ b/include/ck_tile/core/numeric/e8m0.hpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/mxfp_convert.hpp" + +namespace ck_tile { + +/** + * @brief Unsigned representation of a conventional biased Float32 exponent. 
+ * + * bias = 127; + * + * E8M0_1 = 0b01111111; => 2^(127-127) = 1 + * E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2 + * E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8 + * E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256 + * E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768 + * E8M0_MIN = 0b00000000; => 2^-127 + * E8M0_MAX = 0b11111110; => 2^127 + * E8M0_NAN = 0b11111111; => NaN + */ + +struct e8m0_bexp_t +{ + using raw_type = uint8_t; + using type = raw_type; + + raw_type data; + + CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {} + CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {} + CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale) + : e8m0_bexp_t(static_cast(numeric_utils::get_exponent(scale))) + { + } + CK_TILE_HOST_DEVICE constexpr operator type() const { return data; } + CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; } + CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; } + CK_TILE_HOST_DEVICE constexpr operator float() const; + + constexpr bool operator==(const e8m0_bexp_t& other) const { return data == other.data; } + + constexpr bool operator!=(const e8m0_bexp_t& other) const { return data != other.data; } +}; + +using e8m0_t = e8m0_bexp_t; +using e8m0_raw_t = typename e8m0_t::raw_type; + +template <> +struct numeric_traits +{ + using bitwise_type = e8m0_raw_t; + + static constexpr int exp = 8; + static constexpr int mant = 0; + static constexpr int bias = 127; + static constexpr int PackedSize = 1; +}; + +// limits +template +struct numeric; + +template <> +struct numeric +{ + static constexpr e8m0_raw_t binary_min = 0b00000000; // 2^-127 + static constexpr e8m0_raw_t binary_max = 0b11111110; // 2^127 + static constexpr e8m0_raw_t binary_nan = 0b11111111; + CK_TILE_HOST_DEVICE static constexpr e8m0_t min() { return e8m0_t{binary_min}; } + CK_TILE_HOST_DEVICE static constexpr e8m0_t max() { return e8m0_t{binary_max}; } + CK_TILE_HOST_DEVICE static constexpr 
e8m0_t quiet_NaN() { return e8m0_t{binary_nan}; } + CK_TILE_HOST_DEVICE static constexpr e8m0_t signaling_NaN() { return e8m0_t{binary_nan}; } + CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; } + + CK_TILE_HOST_DEVICE static constexpr e8m0_t epsilon() { return signaling_NaN(); } + CK_TILE_HOST_DEVICE static constexpr e8m0_t round_error() { return signaling_NaN(); } + CK_TILE_HOST_DEVICE static constexpr e8m0_t zero() { return signaling_NaN(); } + CK_TILE_HOST_DEVICE static constexpr e8m0_t infinity() { return signaling_NaN(); } +}; + +CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const +{ + using traits = numeric_traits; + if(data == numeric::binary_nan) + { + return traits::NaN; + } + else if(data == 0) + { + return std::numeric_limits::min(); + } + else + { + return bit_cast(static_cast(data) << traits::mant); + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/integral_constant.hpp b/include/ck_tile/core/numeric/integral_constant.hpp index 33c24da8c5..2ba2fd10c6 100644 --- a/include/ck_tile/core/numeric/integral_constant.hpp +++ b/include/ck_tile/core/numeric/integral_constant.hpp @@ -19,14 +19,18 @@ struct constant CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; } }; +template +CK_TILE_HOST_DEVICE static void print(const constant&) +{ + printf("%ld", static_cast(v)); +} + template struct integral_constant : constant { using value_type = T; using type = integral_constant; // using injected-class-name static constexpr T value = v; - // constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; } - // constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } // }; template diff --git a/include/ck_tile/core/numeric/mxfp_convert.hpp b/include/ck_tile/core/numeric/mxfp_convert.hpp index b2e138e880..9b378933d0 100644 --- a/include/ck_tile/core/numeric/mxfp_convert.hpp +++ b/include/ck_tile/core/numeric/mxfp_convert.hpp @@ -12,15 +12,19 @@ 
struct numeric_utils : numeric_traits using traits = numeric_traits; using _numeric = numeric; - using raw_type = typename T::raw_type; + using raw_type = typename traits::bitwise_type; static constexpr int exp_mask = (1 << traits::exp) - 1; - static constexpr int get_exponent(raw_type x) + static constexpr raw_type get_exponent(raw_type x) { // TODO: check if repeated calls are optimized. return (x >> traits::mant) & exp_mask; } + static constexpr raw_type get_exponent(const T& x) + { + return get_exponent(bit_cast(x)); + } static constexpr bool is_positive(raw_type x) { return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero; @@ -33,7 +37,7 @@ struct numeric_utils : numeric_traits static constexpr double get_mantissa(raw_type x) { double mantissa = is_subnormal(x) ? 0.0f : 1.0f; - for(uint32_t i = 0; i < traits::mant; ++i) + for(raw_type i = 0; i < traits::mant; ++i) { mantissa += std::ldexp(static_cast(x & 0b1), -(traits::mant - i)); x >>= 1; @@ -43,22 +47,23 @@ struct numeric_utils : numeric_traits }; template -CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp = 127) +CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale = 1.f) { - using utils = numeric_utils; - static constexpr int e8m0_bias = 127; // TODO: make it generic. - float sign = utils::is_positive(data) ? 1.0 : -1.0; - int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias; - float mant = utils::get_mantissa(data); + using utils = numeric_utils; + float sign = utils::is_positive(data) ? 1.0 : -1.0; + int exp = (utils::is_subnormal(data) ? 
1 : utils::get_exponent(data)) - utils::bias; + float mant = utils::get_mantissa(data); - return std::ldexp(sign * mant, exp + scale_exp - e8m0_bias); + return std::ldexp(sign * mant * scale, exp); } template -CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value) +CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value, float scale = 1.f) { using bitwise_type = typename numeric_traits::bitwise_type; + value /= scale; + if(std::abs(value) > float(numeric::max())) { float max_value = numeric::max(); diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index 0dee750b69..a345cd1b75 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -23,14 +23,11 @@ using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); -CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float); +CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f); // TODO: Add stochastic method struct pk_float4_e2m1_t { - static constexpr int exponent = 2; - static constexpr int mantissa = 1; - static constexpr int bias = 1; // TODO: Can we merge raw_type and type? 
using raw_type = uint8_t; using type = raw_type; @@ -41,18 +38,27 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast(init)} { } - CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init) : data{float_to_e2m1(init)} + CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f) + : data{float_to_e2m1(init, scale)} { } CK_TILE_HOST_DEVICE constexpr operator type() const { return data; } CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; } CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; } - CK_TILE_HOST_DEVICE constexpr operator float() const; - CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const; - CK_TILE_HOST_DEVICE constexpr operator fp16_t() const; - CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const; - CK_TILE_HOST_DEVICE constexpr operator bf16_t() const; - CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const; + + CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp16_t to_fp16(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const; + + CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); } + CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); } + CK_TILE_HOST_DEVICE constexpr operator fp16_t() const { return to_fp16(); } + CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); } + CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); } + CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); } template CK_TILE_HOST_DEVICE constexpr raw_type unpack(number) const; @@ -191,131 +197,160 @@ CK_TILE_DEVICE 
pk_fp4_raw_t _to_f4(T src, float scale = 1.0f) } // namespace impl #endif -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16_t() const +CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const { #if CK_TILE_FP4_CVT_DEVICE - return impl::_from_f4(data); + return impl::_from_f4(data, scale); #else - return bf16_t{type_convert(convert_to_float(unpack(number<0>{})))}; + return bf16_t{type_convert(convert_to_float(unpack(number<0>{}), scale))}; #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16x2_t() const + +CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const { #if CK_TILE_FP4_CVT_DEVICE - return impl::_from_f4(data); + return impl::_from_f4(data, scale); #else - return bf16x2_t{type_convert(convert_to_float(unpack(number<0>{}))), - type_convert(convert_to_float(unpack(number<1>{})))}; + return bf16x2_t{type_convert(convert_to_float(unpack(number<0>{}), scale)), + type_convert(convert_to_float(unpack(number<1>{}), scale))}; #endif } // TODO: make float_to_e2m1 generic so that we can convert from directrly. 
-CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x) +CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale) { #if CK_TILE_FP4_CVT_DEVICE - return impl::_to_f4(x); + return impl::_to_f4(x, scale); #else - return convert_to_type(x); + return convert_to_type(x, scale); #endif } -CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x) { return fp32x2_t(x); } -CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x) { return fp16x2_t(x); } -CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x) { return bf16x2_t(x); } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x) { return float_to_e2m1(x); } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x) +CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale) +{ + return float_to_e2m1(x, scale); +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE - return impl::_to_f4(x); + return impl::_to_f4(x, scale); #else - return float_to_e2m1(type_convert(x)); + return float_to_e2m1(type_convert(x), scale); #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x) +CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE - return impl::_to_f4(x); + return impl::_to_f4(x, scale); #else - return float_to_e2m1(type_convert(x)); + return float_to_e2m1(type_convert(x), scale); #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x) +CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE - return impl::_to_f4(x); + return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(type_convert(x[0])), - float_to_e2m1(type_convert(x[1]))); + return pk_fp4_t::pack(float_to_e2m1(type_convert(x[0]), scale), + float_to_e2m1(type_convert(x[1]), 
scale)); #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x) +CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE - return impl::_to_f4(x); + return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(type_convert(x[0])), - float_to_e2m1(type_convert(x[1]))); + return pk_fp4_t::pack(float_to_e2m1(type_convert(x[0]), scale), + float_to_e2m1(type_convert(x[1]), scale)); #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x) +CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE - return impl::_to_f4(x); + return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(x[0]), float_to_e2m1(x[1])); + return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale)); #endif } +CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale) +{ + return x.to_fp32x2(scale); +} +CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale) +{ + return x.to_fp16x2(scale); +} +CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale) +{ + return x.to_bf16x2(scale); +} +CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale) +{ + return x.to_float(scale); +} +CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale) +{ + return x.to_fp16(scale); +} +CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale) +{ + return x.to_bf16(scale); +} + #if TEST_convert_with_table == 0 -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const +CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const { #if CK_TILE_FP4_CVT_DEVICE - return impl::_from_f4(data); + return impl::_from_f4(data, scale); #else - return convert_to_float(unpack(number<0>{})); + return 
convert_to_float(unpack(number<0>{}), scale); #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const +CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const { #if CK_TILE_FP4_CVT_DEVICE - return impl::_from_f4(data); + return impl::_from_f4(data, scale); #else - return fp32x2_t{convert_to_float(unpack(number<0>{})), - convert_to_float(unpack(number<1>{}))}; + return fp32x2_t{convert_to_float(unpack(number<0>{}), scale), + convert_to_float(unpack(number<1>{}), scale)}; #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const + +CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const { #if CK_TILE_FP4_CVT_DEVICE - return impl::_from_f4(data); + return impl::_from_f4(data, scale); #else - return fp16_t{type_convert(convert_to_float(unpack(number<0>{})))}; + return fp16_t{type_convert(convert_to_float(unpack(number<0>{}), scale))}; #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const +CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const { #if CK_TILE_FP4_CVT_DEVICE - return impl::_from_f4(data); + return impl::_from_f4(data, scale); #else - return fp16x2_t{type_convert(convert_to_float(unpack(number<0>{}))), - type_convert(convert_to_float(unpack(number<1>{})))}; + return fp16x2_t{type_convert(convert_to_float(unpack(number<0>{}), scale)), + type_convert(convert_to_float(unpack(number<1>{}), scale))}; #endif } #else -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const +CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const { - return e2m1_to_fp32_table[data & 0xf]; + return e2m1_to_fp32_table[unpack(number<0>{})] * scale; } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const +CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const { - return fp32x2_t{e2m1_to_fp32_table[data & 0xf], e2m1_to_fp32_table[data >> 4]}; + return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale, 
e2m1_to_fp32_table[unpack(number<1>{}] * scale}; } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const +CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const { - return e2m1_to_fp16_table[data & 0xf]; + return type_convert(e2m1_to_fp16_table[unpack(number<0>{})]) * scale; } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const +CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const { - return fp16x2_t{e2m1_to_fp16_table[data & 0xf], e2m1_to_fp16_table[data >> 4]}; + return fp16x2_t{ + type_convert(type_convert(e2m1_to_fp16_table[unpack(number<0>{})]) * scale), + type_convert(type_convert(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)}; } #endif diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index 94d6e3cd34..1455fce0ea 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -64,6 +64,7 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float) CK_TILE_TYPE_CONVERT(float, float, int8_t, int8) CK_TILE_TYPE_CONVERT(int8_t, int8, float, float) +#undef CK_TILE_TYPE_CONVERT } // namespace ck_tile @@ -71,16 +72,36 @@ CK_TILE_TYPE_CONVERT(int8_t, int8, float, float) namespace ck_tile { -CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2) -CK_TILE_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4) -CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2) -CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4) -CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2) -CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4) -CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float) -CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16) -CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16) -#undef CK_TILE_TYPE_CONVERT +template +CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale); + +#define CK_TILE_SCALED_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \ + template <> \ + 
CK_TILE_HOST_DEVICE constexpr dtype_ scaled_type_convert(stype_ x, \ + float scale) \ + { \ + return sname_##_to_##dname_(x, scale); \ + } \ + template <> \ + CK_TILE_HOST_DEVICE constexpr dtype_ type_convert(stype_ x) \ + { \ + return sname_##_to_##dname_(x, 1.f); \ + } + +CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2) +CK_TILE_SCALED_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4) +CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2) +CK_TILE_SCALED_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4) +CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2) +CK_TILE_SCALED_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4) +CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float) +CK_TILE_SCALED_TYPE_CONVERT(float, float, pk_fp4_t, pk_fp4) +CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16) +CK_TILE_SCALED_TYPE_CONVERT(bf16_t, bf16, pk_fp4_t, pk_fp4) +CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16) +CK_TILE_SCALED_TYPE_CONVERT(fp16_t, fp16, pk_fp4_t, pk_fp4) +#undef CK_TILE_SCALED_TYPE_CONVERT + #endif } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index b165275a8c..58bdb43b08 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -84,7 +84,7 @@ using ext_vector_t = typename impl::ext_vector::type; // by default, any type will result in a vector_size=1 with scalar_type=T traits. // ... 
unless we have other vector_traits specialization -template +template struct vector_traits { using scalar_type = @@ -94,7 +94,7 @@ struct vector_traits // specialization for ext_vector_type() template -struct vector_traits +struct vector_traits { using scalar_type = std::conditional_t, int8_t, T>; static constexpr index_t vector_size = N; diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 4b39773939..ca314a6abe 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -210,28 +210,6 @@ struct buffer_view(const_cast*>(p_data_))); - - // buffer_size_ - printf("buffer_size_: "); - print(buffer_size_); - printf(", "); - - // invalid_element_value_ - printf("invalid_element_value_: "); - print(invalid_element_value_); - - printf("}"); - } }; // Address Space: Global @@ -757,28 +735,6 @@ struct buffer_view(const_cast*>(p_data_))); - - // buffer_size_ - printf("buffer_size_: "); - print(buffer_size_); - printf(", "); - - // invalid_element_value_ - printf("invalid_element_value_: "); - print(invalid_element_value_); - - printf("}"); - } }; // Address Space: LDS @@ -1138,28 +1094,6 @@ struct buffer_view(const_cast*>(p_data_))); - - // buffer_size_ - printf("buffer_size_: "); - print(buffer_size_); - printf(", "); - - // invalid_element_value_ - printf("invalid_element_value_: "); - print(invalid_element_value_); - - printf("}"); - } }; // Address Space: Vgpr @@ -1313,28 +1247,6 @@ struct buffer_view(const_cast*>(p_data_))); - - // buffer_size_ - printf("buffer_size_: "); - print(buffer_size_); - printf(", "); - - // invalid_element_value_ - printf("invalid_element_value_: "); - print(invalid_element_value_); - - printf("}"); - } }; template +CK_TILE_HOST_DEVICE void print(const buffer_view& bv) +{ + printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ", + address_space_to_string(BufferAddressSpace), + static_cast(const_cast*>(bv.p_data_))); + 
print(bv.buffer_size_); + printf(", invalid_element_value_: "); + print(bv.invalid_element_value_); + printf("}"); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp index de99be1965..f7eca73afb 100644 --- a/include/ck_tile/core/tensor/null_tile_window.hpp +++ b/include/ck_tile/core/tensor/null_tile_window.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -53,10 +53,13 @@ struct is_null_tile_window> : public std::true_type }; } // namespace impl +template +constexpr bool is_null_tile_window_v = impl::is_null_tile_window>::value; + template CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&) { - return impl::is_null_tile_window>::value; + return is_null_tile_window_v>; } template diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index e2a6ae6555..ec5538d79c 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -305,42 +305,45 @@ struct tensor_adaptor get_container_subset(vector_strides, top_dims)); } - CK_TILE_HOST_DEVICE void print() const - { - printf("tensor_adaptor{"); - - // - printf("transforms: "); - print(transforms_); - printf(", "); - - // - printf("LowerDimensionHiddenIds: "); - print(LowerDimensionHiddenIdss{}); - printf(", "); - - // - printf("UpperDimensionHiddenIds: "); - print(UpperDimensionHiddenIdss{}); - printf(", "); - - // - printf("BottomDimensionHiddenIds: "); - print(BottomDimensionHiddenIds{}); - printf(", "); - - // - printf("TopDimensionHiddenIds: "); - print(TopDimensionHiddenIds{}); - - printf("}"); - } - private: Transforms transforms_; ElementSize element_size_; }; +template +CK_TILE_HOST_DEVICE static void print(const tensor_adaptor& adaptor) +{ + 
printf("tensor_adaptor{\n"); + printf(" transforms: ["); + print(adaptor.get_transforms()); + printf("],\n"); + + printf(" LowerDimensionHiddenIds: ["); + print(LowerDimensionHiddenIdss{}); + printf("],\n"); + + printf(" UpperDimensionHiddenIds: ["); + print(UpperDimensionHiddenIdss{}); + printf("],\n"); + + printf(" BottomDimensionHiddenIds: ["); + print(BottomDimensionHiddenIds{}); + printf("],\n"); + + // + printf(" TopDimensionHiddenIds: ["); + print(TopDimensionHiddenIds{}); + printf("]\n}\n"); +} + // Transforms: Tuple // LowerDimensionOldTopIdss: Tuple, ...> // UpperDimensionNewTopIdss: Tuple, ...> diff --git a/include/ck_tile/core/tensor/tensor_descriptor.hpp b/include/ck_tile/core/tensor/tensor_descriptor.hpp index 0c3e04f315..0e4787a2f1 100644 --- a/include/ck_tile/core/tensor/tensor_descriptor.hpp +++ b/include/ck_tile/core/tensor/tensor_descriptor.hpp @@ -140,25 +140,37 @@ struct tensor_descriptor : public tensor_adaptor(GuaranteedVectorStrides{})); } - CK_TILE_HOST_DEVICE void print() const - { - printf("tensor_descriptor{"); - - // tensor_adaptor - Base::print(); - printf(", "); - - // element_space_size_ - printf("element_space_size_: "); - print(element_space_size_); - - printf("}"); - } - // TODO make these private ElementSpaceSize element_space_size_; }; +template +CK_TILE_HOST_DEVICE static void print(const tensor_descriptor& descriptor) +{ + printf("tensor_descriptor{\n"); + // first print the tensor adaptor part of the descriptor using the base class print + print(static_cast(descriptor)); + printf("element_space_size_: %ld,\n", + static_cast(descriptor.get_element_space_size().value)); + printf("guaranteed_vector_lengths: "); + print(GuaranteedVectorLengths{}); + printf(",\nguaranteed_vector_strides: "); + print(GuaranteedVectorStrides{}); + printf("}\n}\n"); +} + template CK_TILE_HOST_DEVICE constexpr auto make_tensor_descriptor_from_adaptor(const Adaptor& adaptor, diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp 
b/include/ck_tile/core/tensor/tile_distribution.hpp index 11e6b35c39..bc02ec74d2 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -228,24 +228,6 @@ struct tile_distribution { return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("tile_distribution{"); - // - printf("tile_distribution_encoding: "); - print(DstrEncode{}); - printf(", "); - // - printf("ps_ys_to_xs_: "); - print(ps_ys_to_xs_); - printf(", "); - // - printf("ys_to_d_: "); - print(ys_to_d_); - // - printf("}"); - } }; namespace detail { @@ -710,4 +692,27 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( } } // namespace detail + +// Free print function for tile_distribution +template +CK_TILE_HOST_DEVICE void print(const tile_distribution& distribution) +{ + printf("tile_distribution{"); + printf("tile_distribution_encoding: "); + print(StaticTileDistributionEncoding_{}); + printf(", "); + printf("ps_ys_to_xs_: "); + print(distribution.ps_ys_to_xs_); + printf(", "); + printf("ys_to_d_: "); + print(distribution.ys_to_d_); + printf("}\n"); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp index b380e7c9d8..90d1a2ccb2 100644 --- a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp +++ b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp @@ -428,109 +428,7 @@ struct tile_distribution_encoding { return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum()); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("tile_distribution_encoding::detail{"); - // - printf("ndim_rh_major_: "); - print(ndim_rh_major_); - printf(", "); - // - printf("ndim_span_major_: "); - print(ndim_span_major_); - printf(", "); - // - printf("ndims_rhs_minor_: "); - print(ndims_rhs_minor_); - printf(", "); - // - 
printf("ndim_rh_major_: "); - print(ndim_rh_major_); - printf(", "); - // - printf("max_ndim_rh_minor_: "); - print(max_ndim_rh_minor_); - printf(", "); - // - printf("rhs_lengthss_: "); - print(rhs_lengthss_); - printf(", "); - // - printf("ys_lengths_: "); - print(ys_lengths_); - printf(", "); - // - printf("rhs_major_minor_to_ys_: "); - print(rhs_major_minor_to_ys_); - printf(", "); - // - printf("ndims_span_minor_: "); - print(ndims_span_minor_); - printf(", "); - // - printf("max_ndim_span_minor_: "); - print(max_ndim_span_minor_); - printf(", "); - // - printf("ys_to_span_major_: "); - print(ys_to_span_major_); - printf(", "); - // - printf("ys_to_span_minor_: "); - print(ys_to_span_minor_); - printf(", "); - // - printf("distributed_spans_lengthss_: "); - print(distributed_spans_lengthss_); - printf(", "); - // - printf("ndims_distributed_spans_minor_: "); - print(ndims_distributed_spans_minor_); - printf(", "); - // - printf("ps_over_rs_derivative_: "); - print(ps_over_rs_derivative_); - // - printf("}"); - } }; - - CK_TILE_HOST_DEVICE void print() const - { - printf("tile_distribution_encoding{"); - // - printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY); - // - printf("rs_lengths_: "); - print(rs_lengths_); - printf(", "); - // - printf("hs_lengthss_: "); - print(hs_lengthss_); - printf(", "); - // - printf("ps_to_rhss_major_: "); - print(ps_to_rhss_major_); - printf(", "); - // - printf("ps_to_rhss_minor_: "); - print(ps_to_rhss_minor_); - printf(", "); - // - printf("ys_to_rhs_major_: "); - print(ys_to_rhs_major_); - printf(", "); - // - printf("ys_to_rhs_minor_: "); - print(ys_to_rhs_minor_); - printf(", "); - // - printf("detail: "); - print(detail{}); - // - printf("}"); - } }; template @@ -896,4 +794,106 @@ make_reduce_tile_distribution_encoding(InDstr, sequence reduce } } // namespace detail + +// Free print function for tile_distribution_encoding::detail +template +CK_TILE_HOST_DEVICE void +print(const typename 
tile_distribution_encoding::detail& detail_obj) +{ + printf("tile_distribution_encoding::detail{"); + printf("ndim_rh_major_: "); + print(detail_obj.ndim_rh_major_); + printf(", "); + printf("ndim_span_major_: "); + print(detail_obj.ndim_span_major_); + printf(", "); + printf("ndims_rhs_minor_: "); + print(detail_obj.ndims_rhs_minor_); + printf(", "); + printf("ndim_rh_major_: "); + print(detail_obj.ndim_rh_major_); + printf(", "); + printf("max_ndim_rh_minor_: "); + print(detail_obj.max_ndim_rh_minor_); + printf(", "); + printf("rhs_lengthss_: "); + print(detail_obj.rhs_lengthss_); + printf(", "); + printf("ys_lengths_: "); + print(detail_obj.ys_lengths_); + printf(", "); + printf("rhs_major_minor_to_ys_: "); + print(detail_obj.rhs_major_minor_to_ys_); + printf(", "); + printf("ndims_span_minor_: "); + print(detail_obj.ndims_span_minor_); + printf(", "); + printf("max_ndim_span_minor_: "); + print(detail_obj.max_ndim_span_minor_); + printf(", "); + printf("ys_to_span_major_: "); + print(detail_obj.ys_to_span_major_); + printf(", "); + printf("ys_to_span_minor_: "); + print(detail_obj.ys_to_span_minor_); + printf(", "); + printf("distributed_spans_lengthss_: "); + print(detail_obj.distributed_spans_lengthss_); + printf(", "); + printf("ndims_distributed_spans_minor_: "); + print(detail_obj.ndims_distributed_spans_minor_); + printf(", "); + printf("ps_over_rs_derivative_: "); + print(detail_obj.ps_over_rs_derivative_); + printf("}"); +} + +// Free print function for tile_distribution_encoding +template +CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding& encoding) +{ + printf("tile_distribution_encoding{"); + + printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY); + printf("rs_lengths_: "); + print(encoding.rs_lengths_); + printf(", "); + printf("hs_lengthss_: "); + print(encoding.hs_lengthss_); + printf(", "); + printf("ps_to_rhss_major_: "); + print(encoding.ps_to_rhss_major_); + printf(", "); + 
printf("ps_to_rhss_minor_: "); + print(encoding.ps_to_rhss_minor_); + printf(", "); + printf("ys_to_rhs_major_: "); + print(encoding.ys_to_rhs_major_); + printf(", "); + printf("ys_to_rhs_minor_: "); + print(encoding.ys_to_rhs_minor_); + printf(", "); + printf("}"); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/utility/print.hpp b/include/ck_tile/core/utility/print.hpp new file mode 100644 index 0000000000..04635959af --- /dev/null +++ b/include/ck_tile/core/utility/print.hpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +/// Declare a ck_tile::print() interface that gets specialized in each header file for types that +/// can be printed. +template +CK_TILE_HOST_DEVICE void print(const T&) +{ + static_assert(sizeof(T) == 0, + "No print implementation available for this type. Please specialize " + "ck_tile::print for your type."); +} + +/// Specialization for int +template <> +CK_TILE_HOST_DEVICE void print(const int& value) +{ + printf("%d", value); +} + +/// Specialization for float +template <> +CK_TILE_HOST_DEVICE void print(const float& value) +{ + printf("%f", value); +} + +/// Specialization for double +template <> +CK_TILE_HOST_DEVICE void print(const double& value) +{ + printf("%f", value); +} + +/// Specialization for long +template <> +CK_TILE_HOST_DEVICE void print(const long& value) +{ + printf("%ld", value); +} + +/// Specialization for unsigned int +template <> +CK_TILE_HOST_DEVICE void print(const unsigned int& value) +{ + printf("%u", value); +} + +/// Specialization for char +template <> +CK_TILE_HOST_DEVICE void print(const char& value) +{ + printf("%c", value); +} + +/// Specialization for array +template +CK_TILE_HOST_DEVICE void print(const T (&value)[N]) +{ + printf("["); + for(size_t i = 0; i < N; ++i) + { + if(i > 0) + printf(", "); + print(value[i]); // Recursively call 
print for each element + } + printf("]"); +} + +} // namespace ck_tile diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index c3f1b7d221..b7329fcac7 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -409,7 +409,13 @@ struct HostTensor } // void SetZero() { ck_tile::ranges::fill(mData, 0); } - void SetZero() { std::fill(mData.begin(), mData.end(), 0); } + void SetZero() + { + if constexpr(std::is_same_v) + std::fill(mData.begin(), mData.end(), e8m0_t{1.f}); + else + std::fill(mData.begin(), mData.end(), 0); + } template void ForEach_impl(F&& f, std::vector& idx, size_t rank) diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 30bea193b7..313de5f29a 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -24,8 +24,8 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index ce3bf8fe8d..8b184b18f3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -52,8 +52,6 @@ struct FmhaBwdDQDKDVKernel using BiasGradDataType = ck_tile::remove_cvref_t; static constexpr bool kIsGroupMode = 
FmhaPipeline::kIsGroupMode; - static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; @@ -85,8 +83,6 @@ struct FmhaBwdDQDKDVKernel #define _TS_ std::to_string auto pn = [&] () { std::string n; - if (kPadSeqLenQ) n += "s"; - if (kPadSeqLenK) n += "sk"; if (kPadHeadDimQ) n += "d"; if (kPadHeadDimV) n += "dv"; return n.empty() ? n : std::string("p") + n; }(); @@ -100,7 +96,7 @@ struct FmhaBwdDQDKDVKernel "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" + - ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "_npad" : "_" + pn) + + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) + (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? 
"_deterministic" : "_ndeterministic" ); @@ -1221,7 +1217,7 @@ struct FmhaBwdDQDKDVKernel const auto q_dram = pad_tensor_view( q_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); const auto k_dram_naive = make_naive_tensor_view( k_ptr, @@ -1232,7 +1228,7 @@ struct FmhaBwdDQDKDVKernel const auto k_dram = pad_tensor_view( k_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( @@ -1244,22 +1240,15 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( v_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); - const auto lse_dram = [&]() { - const auto lse_dram_naive = make_naive_tensor_view_packed( - lse_ptr, make_tuple(kargs.seqlen_q), number<1>{}); - return pad_tensor_view( - lse_dram_naive, make_tuple(number{}), sequence{}); - }(); + // lse and d should be fine to read unpaded data as they are not on the reduction dimension + const auto lse_dram = make_naive_tensor_view_packed( + lse_ptr, make_tuple(kargs.seqlen_q), number{}); - const auto d_dram = [&]() { - const auto d_dram_naive = make_naive_tensor_view_packed( - d_ptr, make_tuple(kargs.seqlen_q), number<1>{}); - return pad_tensor_view( - d_dram_naive, make_tuple(number{}), sequence{}); - }(); + const auto d_dram = make_naive_tensor_view_packed( + d_ptr, make_tuple(kargs.seqlen_q), number{}); const auto do_dram_naive = make_naive_tensor_view( do_ptr, @@ -1270,7 +1259,7 @@ struct FmhaBwdDQDKDVKernel const auto do_dram = pad_tensor_view( do_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); auto q_dram_window = make_tile_window( q_dram, @@ -1313,7 +1302,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dq_acc_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); return make_tile_window( @@ -1341,7 +1330,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dq_acc_dram_naive, make_tuple(number{}, number{}), - 
sequence{}); + sequence{}); }(); return make_tile_window( @@ -1376,9 +1365,8 @@ struct FmhaBwdDQDKDVKernel number{}, number<1>{}); - return pad_tensor_view(bias_dram_naive, - bias_dram_window_lengths, - sequence{}); + return pad_tensor_view( + bias_dram_naive, bias_dram_window_lengths, sequence{}); }(); return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0}); @@ -1406,9 +1394,8 @@ struct FmhaBwdDQDKDVKernel number{}, number<1>{}); - return pad_tensor_view(dbias_dram_naive, - bias_dram_window_lengths, - sequence{}); + return pad_tensor_view( + dbias_dram_naive, bias_dram_window_lengths, sequence{}); }(); return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0}); @@ -1495,9 +1482,8 @@ struct FmhaBwdDQDKDVKernel number<1>{}, number<1>{}); - return pad_tensor_view(randval_dram_naive, - randval_dram_window_lengths, - sequence{}); + return pad_tensor_view( + randval_dram_naive, randval_dram_window_lengths, sequence{}); }(); return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0}); @@ -1550,7 +1536,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dk_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); auto dv_dram = [&]() { @@ -1564,7 +1550,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dv_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); auto dk_dram_window = make_tile_window( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 8a13c0b060..1f11569533 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -49,8 +49,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static 
constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; @@ -72,8 +70,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); static constexpr index_t kAlignmentVGrad = kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); - static constexpr index_t kAlignmentBias = - kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + static constexpr index_t kAlignmentBias = 1; static constexpr const char* name = "kr_ktr_vr"; @@ -554,7 +551,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR }); } - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { bool need_perpixel_check = mask.IsEdgeTile( seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index c88b058d32..967fe2362d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -49,8 +49,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; @@ -72,8 +70,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); static constexpr index_t kAlignmentVGrad = kPadHeadDimV ? 
1 : Policy::template GetAlignmentVGrad(); - static constexpr index_t kAlignmentBias = - kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + static constexpr index_t kAlignmentBias = 1; static constexpr const char* name = "kr_ktr_vr_iglp"; @@ -590,7 +587,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP }); } - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { bool need_perpixel_check = mask.IsEdgeTile( seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); @@ -849,7 +845,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP }); } - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { bool need_perpixel_check = mask.IsEdgeTile( seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp new file mode 100644 index 0000000000..80c311de86 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp" + +namespace ck_tile { + +template +class BlockFmhaBwdDQDKDVPipelineSelector +{ + static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV; + + public: + using type = std::conditional_t, + BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP>; +}; + +template +class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector::type +{ + public: + static constexpr const char* name = "auto"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp deleted file mode 100644 index 27f58ef2f8..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp +++ /dev/null @@ -1,15 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -namespace ck_tile { - -// This class is used for codegen pattern matching -enum class BlockFmhaBwdPipelineEnum -{ - KRKTRVR_IGLP = 0, - KRKTRVR, -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp index c4c4a745a7..f6c79c7db6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -55,13 +55,13 @@ struct BlockFmhaBwdPipelineProblem static constexpr bool kIsDeterministic = kIsDeterministic_; // attributes from traits - static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static_assert(!Traits::kPadSeqLenQ, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ"); + static_assert(!Traits::kPadSeqLenK, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ"); }; template + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +using ck_tile::bf16_t; +using ck_tile::bf16x2_t; +using ck_tile::fp16_t; +using ck_tile::fp16x2_t; +using ck_tile::fp32_t; +using ck_tile::fp32x2_t; +using ck_tile::number; +using ck_tile::pk_fp4_t; + +template +CK_TILE_HOST void test_convert(); + +using ck_tile::e8m0_raw_t; +using ck_tile::e8m0_t; + +TEST(OCP_Scale, NumericLimits) +{ + EXPECT_EQ(ck_tile::numeric::has_inf(), false); + EXPECT_EQ(ck_tile::numeric::zero(), ck_tile::numeric::signaling_NaN()); + EXPECT_EQ(ck_tile::numeric::min(), e8m0_t{e8m0_raw_t{0b00000000}}); + EXPECT_EQ(ck_tile::numeric::max(), e8m0_t{e8m0_raw_t{0b11111110}}); +} +TEST(OCP_Scale, NumericBasic) +{ + auto scale_1 = e8m0_t{1.0f}; + auto scale_2 = e8m0_t{e8m0_raw_t{ck_tile::numeric_traits::bias}}; // 2^0 + EXPECT_EQ(scale_1, scale_2); + + auto scale_3 = e8m0_t{8.0f}; + auto scale_4 = e8m0_t{e8m0_raw_t{3 + ck_tile::numeric_traits::bias}}; // 2^3 + EXPECT_EQ(scale_3, scale_4); +} + +TEST(OCP_Scale, ScaledConvertDevice) +{ + constexpr bool is_device = true; + test_convert(); // fp32 -> fp4 -> fp32 + test_convert(); + test_convert(); + test_convert(); + test_convert(); + test_convert(); + test_convert(); +} +TEST(OCP_Scale, 
ScaledConvertHost) +{ + constexpr bool is_device = false; + test_convert(); // fp32 -> fp4 -> fp32 + test_convert(); + test_convert(); + test_convert(); + test_convert(); + test_convert(); + test_convert(); +} +TEST(OCP_Scale, tensorInit) +{ + using scale_t = e8m0_t; + ck_tile::HostTensor scales({10, 10}); + ck_tile::FillUniformDistribution{1.f, 1.f}(scales); + scales.SetZero(); +} + +#define toPF4(x, y) ck_tile::scaled_type_convert(x, y) +#define toDST(x, y) ck_tile::scaled_type_convert(x, y) +#define toDSTx2(x, y) ck_tile::scaled_type_convert(x, y) + +#define toF32(x) ck_tile::type_convert(x) +#define toPF4_(x) ck_tile::type_convert(x) +#define toSRC(x) ck_tile::type_convert(x) +#define toDST_(x) ck_tile::type_convert(x) + +template +__global__ void MyKernel(Args... args) +{ + Kernel{}(args...); +} +template +struct SrcPkfp4Dst +{ + CK_TILE_HOST_DEVICE void + operator()(const SRC* src, DST* dst, e8m0_t scale1, e8m0_t scale2) const + { + + using SRCx2_t = ck_tile::ext_vector_t; + using DSTx2_t = ck_tile::ext_vector_t; + + ck_tile::static_for<0, N, 2>{}([&](auto i) { + const auto input2 = SRCx2_t{src[i], src[i + 1]}; + + if(i % 4 == 0) + { + // ex: fp32_t -> fp4 -> bf16_t + dst[i] = toDST(toPF4(src[i], scale1), scale2); + // ex: fp32x2_t -> pk_fp4 -> unpack<0> -> bf16_t + dst[i + 1] = toDST(toPF4_(toPF4(input2, scale1).unpack(number<1>{})), scale2); + } + else + { + // ex: fp32x2_t -> pk_fp4_t -> bf16x2_t + reinterpret_cast(dst)[i >> 1] = toDSTx2(toPF4(input2, scale1), scale2); + } + }); + } +}; + +template +CK_TILE_HOST void test_convert() +{ + const auto test_data = std::array{4.f, 6.f, 8.f, 10.f}; + const auto ref_data = std::array{8.f, 16.f, 16.f, 16.f}; + const auto scale1 = e8m0_t{8.0f}; + const auto scale2 = e8m0_t{16.0f}; + + static_assert(test_data.size() == ref_data.size()); + static_assert(test_data.size() % 2 == 0); + + constexpr int N = test_data.size(); + std::array in; + std::array ref, out; + + // prepare input and ground truth in host + for(int i = 
0; i < N; ++i) + { + in[i] = toSRC(test_data[i]); + ref[i] = toDST_(ref_data[i]); + EXPECT_EQ(test_data[i], toF32(in[i])); + EXPECT_EQ(ref_data[i], toF32(ref[i])); + } + + using job = SrcPkfp4Dst; + + if constexpr(is_device) + { + auto in_d = std::make_unique(in.size() * sizeof(SRC)); + auto out_d = std::make_unique(out.size() * sizeof(DST)); + in_d->ToDevice(in.data()); + + MyKernel<<<1, 1>>>(reinterpret_cast(in_d->GetDeviceBuffer()), + reinterpret_cast(out_d->GetDeviceBuffer()), + scale1, + scale2); + + out_d->FromDevice(out.data()); + } + else + { + job{}(in.data(), out.data(), scale1, scale2); + } + + for(int i = 0; i < N; ++i) + EXPECT_EQ(ref[i], out[i]) << "i:" << i; +} diff --git a/test/ck_tile/utility/CMakeLists.txt b/test/ck_tile/utility/CMakeLists.txt new file mode 100644 index 0000000000..c57cafca5a --- /dev/null +++ b/test/ck_tile/utility/CMakeLists.txt @@ -0,0 +1,4 @@ +message("-- Adding: test/ck_tile/utility/") + +# Add print tests +add_subdirectory(print) diff --git a/test/ck_tile/utility/print/CMakeLists.txt b/test/ck_tile/utility/print/CMakeLists.txt new file mode 100644 index 0000000000..5300dd20ca --- /dev/null +++ b/test/ck_tile/utility/print/CMakeLists.txt @@ -0,0 +1,8 @@ +# Print utility tests +add_gtest_executable(test_print_sequence test_print_sequence.cpp) +add_gtest_executable(test_print_array test_print_array.cpp) +add_gtest_executable(test_print_tuple test_print_tuple.cpp) +add_gtest_executable(test_print_coordinate_transform test_print_coordinate_transform.cpp) +add_gtest_executable(test_print_static_encoding_pattern test_print_static_encoding_pattern.cpp) +add_gtest_executable(test_print_buffer_view test_print_buffer_view.cpp) +add_gtest_executable(test_print_basic_types test_print_basic_types.cpp) diff --git a/test/ck_tile/utility/print/README.md b/test/ck_tile/utility/print/README.md new file mode 100644 index 0000000000..558c6faee4 --- /dev/null +++ b/test/ck_tile/utility/print/README.md @@ -0,0 +1,70 @@ +# Print Function Tests + 
+This directory contains unit tests for the print functionality of various data structures and coordinate transformations in the composable_kernel library. + +## Tests Included + +### test_print_sequence.cpp +Tests the print functionality for `sequence<...>` containers: +- Simple sequences with multiple elements +- Single-element sequences +- Empty sequences +- Longer sequences + +### test_print_array.cpp +Tests the print functionality for `array<T, N>` containers: +- Arrays with integer values +- Single-element arrays +- Empty arrays (size 0) +- Arrays with floating-point values + +### test_print_tuple.cpp +Tests the print functionality for `tuple<...>` containers: +- Simple tuples with numbers +- Single-element tuples +- Empty tuples +- Mixed-type tuples + +### test_print_coordinate_transform.cpp +Tests the print functionality for coordinate transformation structures: +- `pass_through` transform +- `embed` transform +- `merge` transform +- `unmerge` transform +- `freeze` transform + +## Testing Approach + +All tests use Google Test's `CaptureStdout()` functionality to capture the output from print functions and verify the formatting: + +```cpp +testing::internal::CaptureStdout(); +print(object); +std::string output = testing::internal::GetCapturedStdout(); +EXPECT_EQ(output, "expected_format"); +``` + +This approach enables testing of print function output without affecting the console during test execution. + +## Building and Running + +The tests are integrated into the CMake build system. To build and run the print tests: + +```bash +# Build the specific test +make test_print_sequence + +# Run the test +./test_print_sequence + +# Or run all print tests using CTest +ctest -R "test_print" +``` + +## Adding New Tests + +To add tests for new data structures: + +1. Create a new test file: `test_print_<name>.cpp` +2. Follow the existing pattern using `CaptureStdout()` +3. 
Add the test executable to `CMakeLists.txt` diff --git a/test/ck_tile/utility/print/test_print_array.cpp b/test/ck_tile/utility/print/test_print_array.cpp new file mode 100644 index 0000000000..2fe9bc2a0c --- /dev/null +++ b/test/ck_tile/utility/print/test_print_array.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_print_common.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/utility/print.hpp" + +namespace ck_tile { + +class PrintArrayTest : public PrintTest +{ +}; + +TEST_F(PrintArrayTest, PrintIntArray) +{ + // Test printing array + array arr{10, 20, 30}; + + std::string output = CapturePrintOutput(arr); + + // The expected format should match the array print function implementation + EXPECT_EQ(output, "array{size: 3, data: [10, 20, 30]}"); +} + +TEST_F(PrintArrayTest, PrintSingleElementArray) +{ + // Test printing array + array arr{42}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_EQ(output, "array{size: 1, data: [42]}"); +} + +TEST_F(PrintArrayTest, PrintEmptyArray) +{ + // Test printing array (empty array) + array arr{}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_EQ(output, "array{size: 0, data: []}"); +} + +TEST_F(PrintArrayTest, PrintFloatArray) +{ + // Test printing array with float values + array arr{3.14f, 2.71f}; + + std::string output = CapturePrintOutput(arr); + + // Note: float printing format may vary, so we'll test for basic structure + EXPECT_TRUE(output.find("array{size: 2, data: [") == 0); + EXPECT_TRUE(output.find("3.14") != std::string::npos); + EXPECT_TRUE(output.find("2.71") != std::string::npos); + EXPECT_TRUE(output.find("]}") == output.length() - 2); +} + +} // namespace ck_tile diff --git a/test/ck_tile/utility/print/test_print_basic_types.cpp b/test/ck_tile/utility/print/test_print_basic_types.cpp new file mode 100644 index 0000000000..7a26b6371a --- /dev/null +++ 
b/test/ck_tile/utility/print/test_print_basic_types.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_print_common.hpp" +#include "ck_tile/core/utility/print.hpp" + +namespace ck_tile { + +class PrintBasicTypesTest : public PrintTest +{ +}; + +TEST_F(PrintBasicTypesTest, PrintIntArray) +{ + int arr[4] = {1, 2, 3, 4}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_EQ(output, "[1, 2, 3, 4]"); +} + +TEST_F(PrintBasicTypesTest, PrintFloatArray) +{ + float arr[3] = {1.5f, 2.5f, 3.5f}; + + std::string output = CapturePrintOutput(arr); + + // Note: floating point formatting may vary, so we check for key elements + EXPECT_TRUE(output.find("[") == 0); + EXPECT_TRUE(output.find("1.5") != std::string::npos); + EXPECT_TRUE(output.find("2.5") != std::string::npos); + EXPECT_TRUE(output.find("3.5") != std::string::npos); + EXPECT_TRUE(output.back() == ']'); + EXPECT_TRUE(output.find(", ") != std::string::npos); +} + +TEST_F(PrintBasicTypesTest, PrintDoubleArray) +{ + double arr[2] = {10.123, 20.456}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_TRUE(output.find("[") == 0); + EXPECT_TRUE(output.find("10.123") != std::string::npos); + EXPECT_TRUE(output.find("20.456") != std::string::npos); + EXPECT_TRUE(output.back() == ']'); +} + +TEST_F(PrintBasicTypesTest, PrintUnsignedIntArray) +{ + unsigned int arr[3] = {100u, 200u, 300u}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_EQ(output, "[100, 200, 300]"); +} + +TEST_F(PrintBasicTypesTest, PrintCharArray) +{ + char arr[5] = {'a', 'b', 'c', 'd', 'e'}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_EQ(output, "[a, b, c, d, e]"); +} + +TEST_F(PrintBasicTypesTest, PrintSingleElementArray) +{ + int arr[1] = {42}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_EQ(output, "[42]"); +} + +} // namespace ck_tile diff --git 
a/test/ck_tile/utility/print/test_print_buffer_view.cpp b/test/ck_tile/utility/print/test_print_buffer_view.cpp new file mode 100644 index 0000000000..66668a2103 --- /dev/null +++ b/test/ck_tile/utility/print/test_print_buffer_view.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_print_common.hpp" +#include "ck_tile/core/tensor/buffer_view.hpp" +#include "ck_tile/core/utility/print.hpp" + +namespace ck_tile { + +class PrintBufferViewTest : public PrintTest +{ +}; + +TEST_F(PrintBufferViewTest, PrintGenericBufferView) +{ + // Test printing generic address space buffer_view + float data[4] = {100.f, 200.f, 300.f, 400.f}; + auto bv = make_buffer_view(&data, 4); + + std::string output = CapturePrintOutput(bv); + + // Verify the output contains expected information + EXPECT_TRUE(output.find("buffer_view{AddressSpace: generic") != std::string::npos); + EXPECT_TRUE(output.find("p_data_:") != std::string::npos); + EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos); + EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos); + EXPECT_TRUE(output.find("}") != std::string::npos); +} + +TEST_F(PrintBufferViewTest, PrintGlobalBufferView) +{ + // Test printing global address space buffer_view + float data[4] = {100.f, 200.f, 300.f, 400.f}; + auto bv = make_buffer_view(&data, 4); + + std::string output = CapturePrintOutput(bv); + + // Verify the output contains expected information + EXPECT_TRUE(output.find("buffer_view{AddressSpace: global") != std::string::npos); + EXPECT_TRUE(output.find("p_data_:") != std::string::npos); + EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos); + EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos); + EXPECT_TRUE(output.find("}") != std::string::npos); +} + +TEST_F(PrintBufferViewTest, PrintLdsBufferView) +{ + // Test printing LDS address space buffer_view + float data[4] = {100.f, 200.f, 
300.f, 400.f}; + auto bv = make_buffer_view(data, 4); + + std::string output = CapturePrintOutput(bv); + + // Verify the output contains expected information + EXPECT_TRUE(output.find("buffer_view{AddressSpace: lds") != std::string::npos); + EXPECT_TRUE(output.find("p_data_:") != std::string::npos); + EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos); + EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos); + EXPECT_TRUE(output.find("}") != std::string::npos); +} + +TEST_F(PrintBufferViewTest, PrintVgprBufferView) +{ + // Test printing VGPR address space buffer_view + float data[4] = {1.5f, 2.5f, 3.5f, 4.5f}; + auto bv = make_buffer_view(data, 4); + + std::string output = CapturePrintOutput(bv); + + // Verify the output contains expected information + EXPECT_TRUE(output.find("buffer_view{AddressSpace: vgpr") != std::string::npos); + EXPECT_TRUE(output.find("p_data_:") != std::string::npos); + EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos); + EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos); + EXPECT_TRUE(output.find("}") != std::string::npos); +} + +} // namespace ck_tile diff --git a/test/ck_tile/utility/print/test_print_common.hpp b/test/ck_tile/utility/print/test_print_common.hpp new file mode 100644 index 0000000000..3ba2270802 --- /dev/null +++ b/test/ck_tile/utility/print/test_print_common.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+
+#pragma once
+
+#include <gtest/gtest.h> // ::testing::Test, testing::internal::CaptureStdout/GetCapturedStdout
+#include <string>        // std::string
+
+#include "ck_tile/core/utility/print.hpp"
+
+class PrintTest : public ::testing::Test // base fixture providing CapturePrintOutput()
+{
+    protected:
+    void SetUp() override {}
+    void TearDown() override {}
+    // Calls print(type) while stdout is captured and returns the captured text.
+    template <typename T>
+    std::string CapturePrintOutput(const T& type)
+    {
+        using namespace ck_tile; // so the unqualified print() below can resolve to ck_tile's
+        testing::internal::CaptureStdout();
+        print(type);
+        return testing::internal::GetCapturedStdout();
+    }
+};
diff --git a/test/ck_tile/utility/print/test_print_coordinate_transform.cpp b/test/ck_tile/utility/print/test_print_coordinate_transform.cpp
new file mode 100644
index 0000000000..639b113eb7
--- /dev/null
+++ b/test/ck_tile/utility/print/test_print_coordinate_transform.cpp
@@ -0,0 +1,83 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "test_print_common.hpp"
+#include "ck_tile/core/algorithm/coordinate_transform.hpp"
+#include "ck_tile/core/utility/print.hpp"
+
+namespace ck_tile {
+
+class PrintCoordinateTransformTest : public PrintTest // reuses CapturePrintOutput() unchanged
+{
+};
+
+TEST_F(PrintCoordinateTransformTest, PrintPassThrough)
+{
+    // Print a pass_through transform built from number<32>
+    auto pt = make_pass_through_transform(number<32>{});
+
+    std::string output = CapturePrintOutput(pt);
+
+    // Output must start with "pass_through{", mention up_lengths_ and end with '}'
+    EXPECT_TRUE(output.find("pass_through{") == 0);
+    EXPECT_TRUE(output.find("up_lengths_") != std::string::npos);
+    EXPECT_TRUE(!output.empty() && output.back() == '}'); // guard: back() on "" is UB
+}
+
+TEST_F(PrintCoordinateTransformTest, PrintEmbed)
+{
+    // Print an embed transform built from lengths (4, 8) and coefficients (1, 4)
+    auto embed_transform = make_embed_transform(make_tuple(number<4>{}, number<8>{}),
+                                                make_tuple(number<1>{}, number<4>{}));
+
+    std::string output = CapturePrintOutput(embed_transform);
+
+    // Output must start with "embed{", mention both fields and end with '}'
+    EXPECT_TRUE(output.find("embed{") == 0);
+    EXPECT_TRUE(output.find("up_lengths_") != std::string::npos);
+    EXPECT_TRUE(output.find("coefficients_") != std::string::npos);
+    EXPECT_TRUE(!output.empty() && output.back() == '}'); // guard: back() on "" is UB
+}
+
+TEST_F(PrintCoordinateTransformTest, PrintMerge)
+{
+    // Print a merge transform built from lengths (4, 8)
+    auto merge_transform = make_merge_transform(make_tuple(number<4>{}, number<8>{}));
+
+    std::string output = CapturePrintOutput(merge_transform);
+
+    // Only the "merge" prefix is checked: the concrete type could be
+    // merge_v2_magic_division or merge_v3_division_mod. Either lengths field is accepted.
+    EXPECT_TRUE(output.find("merge") == 0);
+    EXPECT_TRUE(output.find("low_lengths_") != std::string::npos ||
+                output.find("up_lengths_") != std::string::npos);
+    EXPECT_TRUE(!output.empty() && output.back() == '}'); // guard: back() on "" is UB
+}
+
+TEST_F(PrintCoordinateTransformTest, PrintUnmerge)
+{
+    // Print an unmerge transform built from lengths (4, 8)
+    auto unmerge_transform = make_unmerge_transform(make_tuple(number<4>{}, number<8>{}));
+
+    std::string output = CapturePrintOutput(unmerge_transform);
+
+    // Output must start with "unmerge{", mention up_lengths_ and end with '}'
+    EXPECT_TRUE(output.find("unmerge{") == 0);
+    EXPECT_TRUE(output.find("up_lengths_") != std::string::npos);
+    EXPECT_TRUE(!output.empty() && output.back() == '}'); // guard: back() on "" is UB
+}
+
+TEST_F(PrintCoordinateTransformTest, PrintFreeze)
+{
+    // Print a freeze transform built from number<5>
+    auto freeze_transform = make_freeze_transform(number<5>{});
+
+    std::string output = CapturePrintOutput(freeze_transform);
+
+    // Output must start with "freeze{", mention low_idx_ and end with '}'
+    EXPECT_TRUE(output.find("freeze{") == 0);
+    EXPECT_TRUE(output.find("low_idx_") != std::string::npos);
+    EXPECT_TRUE(!output.empty() && output.back() == '}'); // guard: back() on "" is UB
+}
+
+} // namespace ck_tile
diff --git a/test/ck_tile/utility/print/test_print_sequence.cpp b/test/ck_tile/utility/print/test_print_sequence.cpp
new file mode 100644
index 0000000000..e73a9f7e33
--- /dev/null
+++ b/test/ck_tile/utility/print/test_print_sequence.cpp
@@ -0,0 +1,45 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "test_print_common.hpp"
+#include "ck_tile/core/utility/print.hpp"
+#include "ck_tile/core/container/sequence.hpp"
+
+namespace ck_tile {
+
+class PrintSequenceTest : public PrintTest
+{
+};
+
+TEST_F(PrintSequenceTest, PrintSimpleSequence)
+{
+    // Three-element case: sequence<1, 5, 8>
+    constexpr sequence<1, 5, 8> three_values{};
+
+    const std::string printed = CapturePrintOutput(three_values);
+
+    // The complete captured text is compared, not a substring
+    EXPECT_EQ(printed, std::string{"sequence<1, 5, 8>"});
+}
+
+TEST_F(PrintSequenceTest, PrintSingleElementSequence)
+{
+    // One-element case: sequence<42>
+    constexpr sequence<42> single_value{};
+
+    const std::string printed = CapturePrintOutput(single_value);
+
+    EXPECT_EQ(printed, std::string{"sequence<42>"});
+}
+
+TEST_F(PrintSequenceTest, PrintEmptySequence)
+{
+    // Zero-element case: sequence<>
+    constexpr sequence<> no_values{};
+
+    const std::string printed = CapturePrintOutput(no_values);
+
+    EXPECT_EQ(printed, std::string{"sequence<>"});
+}
+
+} // namespace ck_tile
diff --git a/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp b/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp
new file mode 100644
index 0000000000..d1cb408b5c
--- /dev/null
+++ b/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp
@@ -0,0 +1,89 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "test_print_common.hpp"
+#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
+#include "ck_tile/core/utility/print.hpp"
+
+#include <sstream> // std::stringstream, used to build the expected substrings
+
+namespace ck_tile {
+
+class PrintStaticEncodingPatternTest : public PrintTest
+{
+    protected:
+    void TestY0Y1Y2(const std::string& output, auto Y0, auto Y1, auto Y2) // ": <Y0, Y1, Y2>"
+    {
+        std::stringstream expected;
+        expected << ": <" << Y0 << ", " << Y1 << ", " << Y2 << ">";
+        EXPECT_TRUE(output.find(expected.str()) != std::string::npos);
+    }
+    void TestX0X1(const std::string& output, auto X0, auto X1) // expects ": <X0, X1>" in output
+    {
+        std::stringstream expected;
+        expected << ": <" << X0 << ", " << X1 << ">";
+        EXPECT_TRUE(output.find(expected.str()) != std::string::npos);
+    }
+};
+
+TEST_F(PrintStaticEncodingPatternTest, PrintThreadRakedPattern)
+{
+    // Print a thread_raked pattern instantiated with <64, 8, 16, 4>
+    using PatternType =
+        TileDistributionEncodingPattern2D<64, 8, 16, 4, tile_distribution_pattern::thread_raked>;
+    PatternType pattern;
+
+    std::string output = CapturePrintOutput(pattern);
+
+    // Type name, each labelled parameter, the pattern name and the Y/X members must appear
+    EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
+    EXPECT_TRUE(output.find("BlockSize:64") != std::string::npos);
+    EXPECT_TRUE(output.find("YPerTile:8") != std::string::npos);
+    EXPECT_TRUE(output.find("XPerTile:16") != std::string::npos);
+    EXPECT_TRUE(output.find("VecSize:4") != std::string::npos);
+    EXPECT_TRUE(output.find("thread_raked") != std::string::npos);
+    TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2);
+    TestX0X1(output, PatternType::X0, PatternType::X1);
+}
+
+TEST_F(PrintStaticEncodingPatternTest, PrintWarpRakedPattern)
+{
+    // Print a warp_raked pattern instantiated with <128, 16, 32, 8>
+    using PatternType =
+        TileDistributionEncodingPattern2D<128, 16, 32, 8, tile_distribution_pattern::warp_raked>;
+    PatternType pattern;
+
+    std::string output = CapturePrintOutput(pattern);
+
+    // Type name, each labelled parameter, the pattern name and the Y/X members must appear
+    EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
+    EXPECT_TRUE(output.find("BlockSize:128") != std::string::npos);
+    EXPECT_TRUE(output.find("YPerTile:16") != std::string::npos);
+    EXPECT_TRUE(output.find("XPerTile:32") != std::string::npos);
+    EXPECT_TRUE(output.find("VecSize:8") != std::string::npos);
+    EXPECT_TRUE(output.find("warp_raked") != std::string::npos);
+    TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2);
+    TestX0X1(output, PatternType::X0, PatternType::X1);
+}
+
+TEST_F(PrintStaticEncodingPatternTest, PrintBlockRakedPattern)
+{
+    // Print a block_raked pattern instantiated with <256, 32, 64, 16>
+    using PatternType =
+        TileDistributionEncodingPattern2D<256, 32, 64, 16, tile_distribution_pattern::block_raked>;
+    PatternType pattern;
+
+    std::string output = CapturePrintOutput(pattern);
+
+    // Type name, each labelled parameter, the pattern name and the Y/X members must appear
+    EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
+    EXPECT_TRUE(output.find("BlockSize:256") != std::string::npos);
+    EXPECT_TRUE(output.find("YPerTile:32") != std::string::npos);
+    EXPECT_TRUE(output.find("XPerTile:64") != std::string::npos);
+    EXPECT_TRUE(output.find("VecSize:16") != std::string::npos);
+    EXPECT_TRUE(output.find("block_raked") != std::string::npos);
+    TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2);
+    TestX0X1(output, PatternType::X0, PatternType::X1);
+}
+
+} // namespace ck_tile
diff --git a/test/ck_tile/utility/print/test_print_tuple.cpp b/test/ck_tile/utility/print/test_print_tuple.cpp
new file mode 100644
index 0000000000..79aaf1b3af
--- /dev/null
+++ b/test/ck_tile/utility/print/test_print_tuple.cpp
@@ -0,0 +1,66 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "test_print_common.hpp"
+#include "ck_tile/core/container/tuple.hpp"
+#include "ck_tile/core/numeric/integral_constant.hpp"
+#include "ck_tile/core/utility/print.hpp"
+
+namespace ck_tile {
+
+class PrintTupleTest : public PrintTest // reuses CapturePrintOutput() unchanged
+{
+};
+
+TEST_F(PrintTupleTest, PrintSimpleTuple)
+{
+    // Print a tuple of number<1>, number<5>, number<8>
+    auto tup = make_tuple(number<1>{}, number<5>{}, number<8>{});
+
+    std::string output = CapturePrintOutput(tup);
+
+    // Output must start with "tuple<", contain each value and end with '>'
+    EXPECT_TRUE(output.find("tuple<") == 0);
+    EXPECT_TRUE(output.find("1") != std::string::npos);
+    EXPECT_TRUE(output.find("5") != std::string::npos);
+    EXPECT_TRUE(output.find("8") != std::string::npos);
+    EXPECT_TRUE(!output.empty() && output.back() == '>'); // guard: back() on "" is UB
+}
+
+TEST_F(PrintTupleTest, PrintSingleElementTuple)
+{
+    // Print a tuple holding only number<42>
+    auto tup = make_tuple(number<42>{});
+
+    std::string output = CapturePrintOutput(tup);
+
+    EXPECT_TRUE(output.find("tuple<") == 0);
+    EXPECT_TRUE(output.find("42") != std::string::npos);
+    EXPECT_TRUE(!output.empty() && output.back() == '>'); // guard: back() on "" is UB
+}
+
+TEST_F(PrintTupleTest, PrintEmptyTuple)
+{
+    // Print an empty tuple; the full output is compared exactly
+    auto tup = make_tuple();
+
+    std::string output = CapturePrintOutput(tup);
+
+    EXPECT_EQ(output, "tuple<>");
+}
+
+TEST_F(PrintTupleTest, PrintMixedTypeTuple)
+{
+    // Print a tuple mixing number<> and constant<> elements
+    auto tup = make_tuple(number<10>{}, constant<20>{}, number<30>{});
+
+    std::string output = CapturePrintOutput(tup);
+
+    EXPECT_TRUE(output.find("tuple<") == 0);
+    EXPECT_TRUE(output.find("10") != std::string::npos);
+    EXPECT_TRUE(output.find("20") != std::string::npos);
+    EXPECT_TRUE(output.find("30") != std::string::npos);
+    EXPECT_TRUE(!output.empty() && output.back() == '>'); // guard: back() on "" is UB
+}
+
+} // namespace ck_tile