mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Merge commit '5d6d236b255b4ef9c8f38e1bd35975acda0af19a' into develop
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
@@ -8,21 +8,13 @@ import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Dict, Literal
|
||||
from collections import defaultdict
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
from codegen.utils import update_file
|
||||
|
||||
|
||||
BWD_DQDKDV_PIPELINE_MAP = {
|
||||
"kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP",
|
||||
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR",
|
||||
}
|
||||
|
||||
BWD_DQDKDV_PIPELINE_ENUM_MAP = {
|
||||
"kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP",
|
||||
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR",
|
||||
}
|
||||
|
||||
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
@@ -56,8 +48,8 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx}
|
||||
fmha_block_warps2_{F_idx},
|
||||
fmha_warp_tile0_{F_idx}>;
|
||||
|
||||
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<false, /* kPadSeqLenQ */
|
||||
false, /* kPadSeqLenK */
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
false,
|
||||
@@ -93,18 +85,18 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
fmha_dropout_{F_idx},
|
||||
fmha_bwd_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<fmha_bwd_pipeline_problem_{F_idx}>;
|
||||
using fmha_bwd_pipeline_{F_idx} = ck_tile::BlockFmhaBwdDQDKDVPipeline<fmha_bwd_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
|
||||
{F_skpad},
|
||||
false,
|
||||
{F_dpad}>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
|
||||
{F_skpad},
|
||||
false,
|
||||
{F_dvpad}>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
|
||||
@@ -115,13 +107,10 @@ using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
|
||||
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
|
||||
{F_dtype},
|
||||
{F_mode},
|
||||
{F_pipeline_enum},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_dropout_{F_idx},
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_deterministic}>;
|
||||
@@ -195,15 +184,18 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
"""
|
||||
|
||||
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>;
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
# M0 size for 1d kernels (dot/convert)
|
||||
M0_1D = 64
|
||||
|
||||
# GEMM0: Q@K=S^T
|
||||
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
|
||||
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
|
||||
@@ -249,8 +241,6 @@ class FmhaBwdDQDKDVKernel:
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_tile : FmhaBwdDQDKDVTileSize
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_bias : str #
|
||||
@@ -259,7 +249,6 @@ class FmhaBwdDQDKDVKernel:
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_deterministic : str #
|
||||
F_pipeline : str #
|
||||
mask_impl : str #
|
||||
|
||||
@property
|
||||
@@ -293,8 +282,6 @@ class FmhaBwdDQDKDVKernel:
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_spad = BOOL_MAP[self.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_dvpad],
|
||||
F_bias = BIAS_MAP[self.F_bias],
|
||||
@@ -304,21 +291,18 @@ class FmhaBwdDQDKDVKernel:
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_deterministic = BOOL_MAP[self.F_deterministic],
|
||||
F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
|
||||
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}'
|
||||
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
@@ -347,20 +331,15 @@ class FmhaBwdDQDKDVKernel:
|
||||
return self.name + ".cpp"
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size & pipeline.
|
||||
# this is current supported tile size.
|
||||
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
# '160' : [FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# "kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"]
|
||||
'32' : FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
'64' : FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
'128' : FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# '160' : FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
'256' : FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -375,7 +354,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
|
||||
/* BlockSize = */ 64,
|
||||
/* BlockSize = M0 = */ 64,
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
fmha_bwd_dot_do_o_trait_{F_idx}>;
|
||||
@@ -580,7 +559,6 @@ class FmhaBwdConvertQGradKernel:
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdApiTrait:
|
||||
idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
pipeline : str
|
||||
# sync with fmha_bwd_traits<>, to generate fallback calls
|
||||
hdim : int
|
||||
dtype : str # data type
|
||||
@@ -590,9 +568,7 @@ class FmhaBwdApiTrait:
|
||||
bias : str
|
||||
dbias : str
|
||||
dropout : str
|
||||
spad : str
|
||||
spad1 : str # spad for dot/convert kernel
|
||||
skpad : str
|
||||
spad1d : str # spad for 1d kernels (dot/convert)
|
||||
dpad : str
|
||||
dvpad : str
|
||||
deterministic : str
|
||||
@@ -611,24 +587,14 @@ class FmhaBwdApiTrait:
|
||||
def bhdv(self) -> int:
|
||||
return self.tile.F_bhdv
|
||||
|
||||
def scheck(self, spad1 : str) -> str:
|
||||
if self.mode == 'group':
|
||||
return 'true' # always support
|
||||
elif self.spad == 't' and spad1 == 't':
|
||||
return f'a.seqlen_q % {self.bm0} != 0'
|
||||
elif self.spad == 'f' and spad1 == 't':
|
||||
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
|
||||
else: # self.skpad == 'f' and skpad1 == 'f'
|
||||
return 'a.seqlen_q % 64 == 0'
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group':
|
||||
return 'true' # always support
|
||||
elif self.skpad == 't':
|
||||
return f'a.seqlen_k % {self.bn0} != 0'
|
||||
else:
|
||||
return f'a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.spad1d == 't':
|
||||
return f'a.seqlen_q % {M0_1D} != 0'
|
||||
else: # self.spad1d == 'f'
|
||||
return f'a.seqlen_q % {M0_1D} == 0'
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
@@ -647,14 +613,14 @@ class FmhaBwdApiTrait:
|
||||
def get_occupancy(dtype, hdim):
|
||||
return 2
|
||||
|
||||
return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1,
|
||||
return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d,
|
||||
F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim))
|
||||
|
||||
@property
|
||||
def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
|
||||
return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile,
|
||||
F_spad=self.spad, F_skpad=self.skpad, F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias,
|
||||
F_dbias=self.dbias, F_dropout=self.dropout, F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, F_pipeline=self.pipeline, mask_impl=self.mask_impl)
|
||||
F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout,
|
||||
F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl)
|
||||
|
||||
@property
|
||||
def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel:
|
||||
@@ -664,48 +630,46 @@ class FmhaBwdApiTrait:
|
||||
return 2
|
||||
|
||||
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
|
||||
F_bm0=64, F_bn0=self.tile.F_bn0, F_spad=self.spad, F_dpad=self.dpad,
|
||||
F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
|
||||
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
|
||||
F_deterministic=self.deterministic)
|
||||
|
||||
class FmhaBwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.dq_dk_dv_pool = dict()
|
||||
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(list))
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.dq_dk_dv_pool.keys():
|
||||
self.dq_dk_dv_pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys():
|
||||
self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list()
|
||||
|
||||
self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@staticmethod
|
||||
def if_(i: int) -> str:
|
||||
return 'if' if i == 0 else 'else if'
|
||||
|
||||
def _api_innders(self, traits: List[FmhaBwdApiTrait]) -> str:
|
||||
inners = ""
|
||||
i = 0
|
||||
for trait in traits:
|
||||
inners += FMHA_BWD_API_INNER_DISPATCH.format(F_if=self.if_(i), F_mode=MODE_MAP[trait.mode],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
|
||||
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_deterministic=BOOL_MAP[trait.deterministic])
|
||||
i += 1
|
||||
return inners
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.dq_dk_dv_pool.keys()):
|
||||
for i, dtype in enumerate(self.dq_dk_dv_pool):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
|
||||
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype]):
|
||||
traits=self.dq_dk_dv_pool[dtype][hdim]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
for spad1 in ["t", "f"]:
|
||||
if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
|
||||
continue
|
||||
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
|
||||
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_deterministic=BOOL_MAP[trait.deterministic])
|
||||
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
inners = self._api_innders(traits)
|
||||
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(j), F_hdim=hdim, F_inner_dispatch=inners)
|
||||
per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(i), F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
@@ -730,21 +694,16 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad, spad1, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 6)):
|
||||
tile = d[hdim_str][0]
|
||||
ppl = d[hdim_str][1]
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
if (mode == "group") and (spad == "f" or skpad == "f"):
|
||||
continue
|
||||
if (spad1 == "f") and (spad == "t" or mode == "group"):
|
||||
if (mode == "group") and (spad1d == "f"):
|
||||
continue
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
if ("wg32" in dropout):
|
||||
continue
|
||||
if (dpad == "t" or dvpad == "t"):
|
||||
ppl = d[hdim_str][2]
|
||||
t = FmhaBwdApiTrait(idx=0, pipeline=ppl, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad=spad, spad1=spad1, skpad=skpad, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl)
|
||||
t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl)
|
||||
|
||||
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
|
||||
continue
|
||||
@@ -808,13 +767,13 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
|
||||
def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list)
|
||||
(output_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
|
||||
update_file(output_dir / FMHA_BWD_API_FILENAME, api_pool.api)
|
||||
for k in kernels_dot_do_o:
|
||||
(output_dir / k.filename).write_text(k.template)
|
||||
update_file(output_dir / k.filename, k.template)
|
||||
for k in kernels_convert_dq:
|
||||
(output_dir / k.filename).write_text(k.template)
|
||||
update_file(output_dir / k.filename, k.template)
|
||||
for k in kernels_dq_dk_dv:
|
||||
(output_dir / k.filename).write_text(k.template)
|
||||
update_file(output_dir / k.filename, k.template)
|
||||
|
||||
|
||||
def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None:
|
||||
|
||||
21
example/ck_tile/01_fmha/codegen/utils.py
Normal file
21
example/ck_tile/01_fmha/codegen/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import os.path as path
|
||||
|
||||
|
||||
def update_file(file_path, content):
|
||||
"""Update the file at file_path with the given content if it differs from the existing content.
|
||||
|
||||
It avoids unnecessary touching of the file which triggers rebuilds
|
||||
"""
|
||||
|
||||
existing_content = ""
|
||||
if path.exists(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
existing_content = file.read()
|
||||
if existing_content == content:
|
||||
return
|
||||
with open(file_path, "w") as file:
|
||||
file.write(content)
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -357,31 +357,25 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
typename FmhaDropout_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_dq_dk_dv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/e8m0.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
@@ -74,6 +75,7 @@
|
||||
#include "ck_tile/core/utility/literals.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/philox_rand.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator.hpp"
|
||||
#include "ck_tile/core/utility/static_counter.hpp"
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -139,20 +140,19 @@ struct pass_through : public base_transform<1, 1>
|
||||
{
|
||||
return make_tuple(low_vector_lengths, low_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("pass_through{");
|
||||
|
||||
//
|
||||
printf("up_lengths_:");
|
||||
print(up_lengths_);
|
||||
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength>
|
||||
CK_TILE_HOST_DEVICE static void print(const pass_through<LowLength>& pt)
|
||||
{
|
||||
printf("pass_through{");
|
||||
|
||||
printf("up_lengths_: ");
|
||||
print(pt.get_upper_lengths());
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength,
|
||||
typename LeftPadLength,
|
||||
typename RightPadLength,
|
||||
@@ -229,29 +229,25 @@ struct pad : public base_transform<1, 1>
|
||||
ck_tile::is_known_at_compile_time<LeftPadLength>::value &&
|
||||
ck_tile::is_known_at_compile_time<RightPadLength>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("pad{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("left_pad_length_: ");
|
||||
print(left_pad_length_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("right_pad_length_: ");
|
||||
print(right_pad_length_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength,
|
||||
typename LeftPadLength,
|
||||
typename RightPadLength,
|
||||
bool SkipIsValidCheck>
|
||||
CK_TILE_HOST_DEVICE static void
|
||||
print(const pad<LowLength, LeftPadLength, RightPadLength, SkipIsValidCheck>& p)
|
||||
{
|
||||
printf("pad{");
|
||||
printf("up_lengths_: ");
|
||||
print(p.up_lengths_);
|
||||
printf(", left_pad_length_: ");
|
||||
print(p.left_pad_length_);
|
||||
printf(", right_pad_length_: ");
|
||||
print(p.right_pad_length_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
|
||||
struct left_pad
|
||||
{
|
||||
@@ -330,24 +326,20 @@ struct left_pad
|
||||
// It's up to runtime to check the padding length should be multiple of vector length
|
||||
return make_tuple(low_vector_lengths, low_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("left_pad{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("left_pad_length_: ");
|
||||
print(left_pad_length_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck>
|
||||
CK_TILE_HOST_DEVICE static void
|
||||
print(const left_pad<LowLength, LeftPadLength, SkipIsValidCheck>& lp)
|
||||
{
|
||||
printf("left_pad{");
|
||||
printf("up_lengths_: ");
|
||||
print(lp.up_lengths_);
|
||||
printf(", left_pad_length_: ");
|
||||
print(lp.left_pad_length_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
|
||||
struct right_pad : public base_transform<1, 1>
|
||||
{
|
||||
@@ -430,24 +422,20 @@ struct right_pad : public base_transform<1, 1>
|
||||
// It's up to runtime to check the padding length should be multiple of vector length
|
||||
return make_tuple(low_vector_lengths, low_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("right_pad{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("right_pad_length_: ");
|
||||
print(right_pad_length_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
|
||||
CK_TILE_HOST_DEVICE static void
|
||||
print(const right_pad<LowLength, RightPadLength, SkipIsValidCheck>& rp)
|
||||
{
|
||||
printf("right_pad{");
|
||||
printf("up_lengths_: ");
|
||||
print(rp.up_lengths_);
|
||||
printf(", right_pad_length_: ");
|
||||
print(rp.right_pad_length_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
|
||||
// UpLengths and Coefficients can be either of the followings:
|
||||
// 1) Tuple of index_t, which is known at run-time, or
|
||||
@@ -532,24 +520,19 @@ struct embed : public base_transform<1, UpLengths::size()>
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
ck_tile::is_known_at_compile_time<Coefficients>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("embed{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("coefficients_: ");
|
||||
print(coefficients_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLengths, typename Coefficients>
|
||||
CK_TILE_HOST_DEVICE static void print(const embed<UpLengths, Coefficients>& e)
|
||||
{
|
||||
printf("embed{");
|
||||
printf("up_lengths_: ");
|
||||
print(e.up_lengths_);
|
||||
printf(", coefficients_: ");
|
||||
print(e.coefficients_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
|
||||
{
|
||||
@@ -699,24 +682,19 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("merge_v2_magic_division{");
|
||||
|
||||
//
|
||||
printf("low_lengths_ ");
|
||||
print(low_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("up_lengths_ ");
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const merge_v2_magic_division<LowLengths>& m)
|
||||
{
|
||||
printf("merge_v2_magic_division{");
|
||||
printf("low_lengths_: ");
|
||||
print(m.low_lengths_);
|
||||
printf(", up_lengths_: ");
|
||||
print(m.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// Implementation of "merge" transformation primitive that uses division and mod. It is supposed to
|
||||
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
|
||||
// will be very bad
|
||||
@@ -830,29 +808,21 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("Merge_v3_direct_division_mod{");
|
||||
|
||||
//
|
||||
printf("low_lengths_ ");
|
||||
print(low_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("low_lengths_scan_ ");
|
||||
print(low_lengths_scan_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("up_lengths_ ");
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const merge_v3_division_mod<LowLengths>& m)
|
||||
{
|
||||
printf("merge_v3_division_mod{");
|
||||
printf("low_lengths_: ");
|
||||
print(m.low_lengths_);
|
||||
printf(", low_lengths_scan_: ");
|
||||
print(m.low_lengths_scan_);
|
||||
printf(", up_lengths_: ");
|
||||
print(m.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation>
|
||||
struct unmerge : public base_transform<1, UpLengths::size()>
|
||||
{
|
||||
@@ -958,24 +928,19 @@ struct unmerge : public base_transform<1, UpLengths::size()>
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("unmerge{");
|
||||
|
||||
//
|
||||
printf("up_lengths_");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("up_lengths_scan_");
|
||||
print(up_lengths_scan_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation>
|
||||
CK_TILE_HOST_DEVICE static void print(const unmerge<UpLengths, Use24BitIntegerCalculation>& u)
|
||||
{
|
||||
printf("unmerge{");
|
||||
printf("up_lengths_: ");
|
||||
print(u.up_lengths_);
|
||||
printf(", up_lengths_scan_: ");
|
||||
print(u.up_lengths_scan_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowerIndex>
|
||||
struct freeze : public base_transform<1, 0>
|
||||
{
|
||||
@@ -1023,19 +988,17 @@ struct freeze : public base_transform<1, 0>
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<LowerIndex>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("freeze{");
|
||||
|
||||
//
|
||||
printf("low_idx_: ");
|
||||
print(low_idx_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowerIndex>
|
||||
CK_TILE_HOST_DEVICE static void print(const freeze<LowerIndex>& f)
|
||||
{
|
||||
printf("freeze{");
|
||||
printf("low_idx_: ");
|
||||
print(f.low_idx_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// insert a dangling upper dimension without lower dimension
|
||||
template <typename UpperLength>
|
||||
struct insert : public base_transform<0, 1>
|
||||
@@ -1092,18 +1055,17 @@ struct insert : public base_transform<0, 1>
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpperLength>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("insert{");
|
||||
|
||||
//
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpperLength>
|
||||
CK_TILE_HOST_DEVICE static void print(const insert<UpperLength>& i)
|
||||
{
|
||||
printf("insert{");
|
||||
printf("up_lengths_: ");
|
||||
print(i.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// replicate the original tensor and create a higher dimensional tensor
|
||||
template <typename UpLengths>
|
||||
struct replicate : public base_transform<0, UpLengths::size()>
|
||||
@@ -1152,21 +1114,19 @@ struct replicate : public base_transform<0, UpLengths::size()>
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("replicate{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
//
|
||||
UpLengths up_lengths_;
|
||||
};
|
||||
|
||||
template <typename UpLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const replicate<UpLengths>& r)
|
||||
{
|
||||
printf("replicate{");
|
||||
printf("up_lengths_: ");
|
||||
print(r.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength, typename SliceBegin, typename SliceEnd>
|
||||
struct slice : public base_transform<1, 1>
|
||||
{
|
||||
@@ -1238,28 +1198,20 @@ struct slice : public base_transform<1, 1>
|
||||
ck_tile::is_known_at_compile_time<SliceBegin>::value &&
|
||||
ck_tile::is_known_at_compile_time<SliceEnd>::value;
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("slice{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("slice_begin_: ");
|
||||
print(slice_begin_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("slice_end_: ");
|
||||
print(slice_end_);
|
||||
|
||||
printf("}");
|
||||
} // namespace ck
|
||||
}; // namespace ck
|
||||
template <typename LowLength, typename SliceBegin, typename SliceEnd>
|
||||
CK_TILE_HOST_DEVICE static void print(const slice<LowLength, SliceBegin, SliceEnd>& s)
|
||||
{
|
||||
printf("slice{");
|
||||
printf("up_lengths_: ");
|
||||
print(s.up_lengths_);
|
||||
printf(", slice_begin_: ");
|
||||
print(s.slice_begin_);
|
||||
printf(", slice_end_: ");
|
||||
print(s.slice_end_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
/*
|
||||
* \brief lower_idx = upper_idx % modulus.
|
||||
@@ -1328,19 +1280,19 @@ struct modulo : public base_transform<1, 1>
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("Modulus{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Modulus, typename UpLength>
|
||||
CK_TILE_HOST_DEVICE static void print(const modulo<Modulus, UpLength>& m)
|
||||
{
|
||||
printf("modulo{");
|
||||
printf("modulus_: ");
|
||||
print(m.modulus_);
|
||||
printf(", up_lengths_: ");
|
||||
print(m.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// 2D XOR, NOTE: "xor" is a keyword
|
||||
template <typename LowLengths>
|
||||
struct xor_t : public base_transform<2, 2>
|
||||
@@ -1424,20 +1376,17 @@ struct xor_t : public base_transform<2, 2>
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("xor_t{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths>& x)
|
||||
{
|
||||
printf("xor_t{");
|
||||
printf("up_lengths_: ");
|
||||
print(x.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength, typename OffsetLength>
|
||||
struct offset : public base_transform<1, 1>
|
||||
{
|
||||
@@ -1509,24 +1458,19 @@ struct offset : public base_transform<1, 1>
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
ck_tile::is_known_at_compile_time<OffsetLength>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("offset{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("offset_length_: ");
|
||||
print(offset_length_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength, typename OffsetLength>
|
||||
CK_TILE_HOST_DEVICE static void print(const offset<LowLength, OffsetLength>& o)
|
||||
{
|
||||
printf("offset{");
|
||||
printf("up_lengths_: ");
|
||||
print(o.up_lengths_);
|
||||
printf(", offset_length_: ");
|
||||
print(o.offset_length_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename UpLength, typename IndexingAdaptor>
|
||||
struct indexing : public base_transform<1, 1>
|
||||
{
|
||||
@@ -1595,20 +1539,19 @@ struct indexing : public base_transform<1, 1>
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
IndexingAdaptor::is_known_at_compile_time();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("embed{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLength, typename IndexingAdaptor>
|
||||
CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>& i)
|
||||
{
|
||||
printf("indexing{");
|
||||
printf("up_lengths_: ");
|
||||
print(i.up_lengths_);
|
||||
printf(", iadaptor_: ");
|
||||
print(i.iadaptor_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
//*******************************************************************************************************
|
||||
|
||||
template <typename LowLength>
|
||||
|
||||
@@ -77,6 +77,7 @@
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -317,4 +318,51 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
}
|
||||
};
|
||||
|
||||
// Helper function to convert enum to string
|
||||
constexpr const char* tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
|
||||
{
|
||||
switch(pattern)
|
||||
{
|
||||
case tile_distribution_pattern::thread_raked: return "thread_raked";
|
||||
case tile_distribution_pattern::warp_raked: return "warp_raked";
|
||||
case tile_distribution_pattern::block_raked: return "block_raked";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern,
|
||||
index_t NumWaveGroups>
|
||||
CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>&)
|
||||
{
|
||||
using PatternType = TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>;
|
||||
|
||||
printf("TileDistributionEncodingPattern2D<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
|
||||
"VecSize:%d, %s>: ",
|
||||
BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern_to_string(DistributionPattern));
|
||||
printf("{<Y0, Y1, Y2>: <%d, %d, %d>, <X0, X1>: <%d, %d>}\n",
|
||||
PatternType::Y0,
|
||||
PatternType::Y1,
|
||||
PatternType::Y2,
|
||||
PatternType::X0,
|
||||
PatternType::X1);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -218,4 +218,19 @@ CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Helper function to convert address space enum to string
|
||||
CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space)
|
||||
{
|
||||
switch(addr_space)
|
||||
{
|
||||
case address_space_enum::generic: return "generic";
|
||||
case address_space_enum::global: return "global";
|
||||
case address_space_enum::lds: return "lds";
|
||||
case address_space_enum::sgpr: return "sgpr";
|
||||
case address_space_enum::constant: return "constant";
|
||||
case address_space_enum::vgpr: return "vgpr";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -177,9 +177,27 @@ struct array<T, 0>
|
||||
CK_TILE_HOST_DEVICE constexpr array() {}
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
|
||||
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_HOST_DEVICE static void print(const array<T, N>& a)
|
||||
{
|
||||
printf("array{size: %ld, data: [", static_cast<long>(N));
|
||||
for(index_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(i > 0)
|
||||
printf(", ");
|
||||
print(a[i]);
|
||||
}
|
||||
printf("]}");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static void print(const array<T, 0>&)
|
||||
{
|
||||
printf("array{size: 0, data: []}");
|
||||
}
|
||||
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
|
||||
@@ -139,26 +139,21 @@ struct map
|
||||
|
||||
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
|
||||
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("map{size_: %d, ", size_);
|
||||
//
|
||||
printf("impl_: [");
|
||||
//
|
||||
for(const auto& [k, d] : *this)
|
||||
{
|
||||
printf("{key: ");
|
||||
print(k);
|
||||
printf(", data: ");
|
||||
print(d);
|
||||
printf("}, ");
|
||||
}
|
||||
//
|
||||
printf("]");
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename key, typename data, index_t max_size>
|
||||
CK_TILE_HOST_DEVICE static void print(const map<key, data, max_size>& m)
|
||||
{
|
||||
printf("map{size_: %d, impl_: [", m.size_);
|
||||
for(const auto& [k, d] : m)
|
||||
{
|
||||
printf("{key: ");
|
||||
print(k);
|
||||
printf(", data: ");
|
||||
print(d);
|
||||
printf("}, ");
|
||||
}
|
||||
printf("]}");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -9,13 +9,10 @@
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t, index_t, index_t>
|
||||
struct static_for;
|
||||
|
||||
template <index_t...>
|
||||
struct sequence;
|
||||
|
||||
@@ -196,15 +193,24 @@ struct sequence
|
||||
{
|
||||
return sequence<f(Is)...>{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static void print()
|
||||
{
|
||||
printf("sequence{size: %d, data: [", size());
|
||||
((printf("%d ", Is)), ...);
|
||||
printf("]}");
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE static void print(const sequence<Is...>&)
|
||||
{
|
||||
printf("sequence<");
|
||||
if constexpr(sizeof...(Is) > 0)
|
||||
{
|
||||
bool first = true;
|
||||
(([&first](index_t value) {
|
||||
printf("%s%d", first ? "" : ", ", value);
|
||||
first = false;
|
||||
}(Is)),
|
||||
...);
|
||||
}
|
||||
printf(">");
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
template <typename T, T... Ints>
|
||||
struct __integer_sequence;
|
||||
|
||||
@@ -300,12 +300,29 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
#undef TP_COM_
|
||||
};
|
||||
|
||||
template <typename, typename = void>
|
||||
template <typename... T>
|
||||
CK_TILE_HOST_DEVICE void print(const tuple<T...>& t)
|
||||
{
|
||||
printf("tuple<");
|
||||
if constexpr(sizeof...(T) > 0)
|
||||
{
|
||||
bool first = true;
|
||||
static_for<0, sizeof...(T), 1>{}([&t, &first](auto i) {
|
||||
if(!first)
|
||||
printf(", ");
|
||||
print(t.get(i));
|
||||
first = false;
|
||||
});
|
||||
}
|
||||
printf(">");
|
||||
}
|
||||
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for array
|
||||
template <typename... T>
|
||||
struct vector_traits<tuple<T...>>
|
||||
struct vector_traits<tuple<T...>, void>
|
||||
{
|
||||
using scalar_type = __type_pack_element<0, T...>;
|
||||
static constexpr index_t vector_size = sizeof...(T);
|
||||
|
||||
102
include/ck_tile/core/numeric/e8m0.hpp
Normal file
102
include/ck_tile/core/numeric/e8m0.hpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Unsigned representation of a conventional biased Float32 exponent.
|
||||
*
|
||||
* bias = 127;
|
||||
*
|
||||
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
|
||||
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
|
||||
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
|
||||
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
|
||||
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
|
||||
* E8M0_MIN = 0b00000000; => 2^-127
|
||||
* E8M0_MAX = 0b11111110; => 2^127
|
||||
* E8M0_NAN = 0b11111111; => NaN
|
||||
*/
|
||||
|
||||
struct e8m0_bexp_t
|
||||
{
|
||||
using raw_type = uint8_t;
|
||||
using type = raw_type;
|
||||
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale)
|
||||
: e8m0_bexp_t(static_cast<type>(numeric_utils<float>::get_exponent(scale)))
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const;
|
||||
|
||||
constexpr bool operator==(const e8m0_bexp_t& other) const { return data == other.data; }
|
||||
|
||||
constexpr bool operator!=(const e8m0_bexp_t& other) const { return data != other.data; }
|
||||
};
|
||||
|
||||
using e8m0_t = e8m0_bexp_t;
|
||||
using e8m0_raw_t = typename e8m0_t::raw_type;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<e8m0_t>
|
||||
{
|
||||
using bitwise_type = e8m0_raw_t;
|
||||
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 0;
|
||||
static constexpr int bias = 127;
|
||||
static constexpr int PackedSize = 1;
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<e8m0_t>
|
||||
{
|
||||
static constexpr e8m0_raw_t binary_min = 0b00000000; // 2^-127
|
||||
static constexpr e8m0_raw_t binary_max = 0b11111110; // 2^127
|
||||
static constexpr e8m0_raw_t binary_nan = 0b11111111;
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t min() { return e8m0_t{binary_min}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t max() { return e8m0_t{binary_max}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t quiet_NaN() { return e8m0_t{binary_nan}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t signaling_NaN() { return e8m0_t{binary_nan}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t epsilon() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t round_error() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t zero() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t infinity() { return signaling_NaN(); }
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const
|
||||
{
|
||||
using traits = numeric_traits<float>;
|
||||
if(data == numeric<e8m0_t>::binary_nan)
|
||||
{
|
||||
return traits::NaN;
|
||||
}
|
||||
else if(data == 0)
|
||||
{
|
||||
return std::numeric_limits<float>::min();
|
||||
}
|
||||
else
|
||||
{
|
||||
return bit_cast<float>(static_cast<traits::bitwise_type>(data) << traits::mant);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -19,14 +19,18 @@ struct constant
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
template <auto v>
|
||||
CK_TILE_HOST_DEVICE static void print(const constant<v>&)
|
||||
{
|
||||
printf("%ld", static_cast<long>(v));
|
||||
}
|
||||
|
||||
template <typename T, T v>
|
||||
struct integral_constant : constant<v>
|
||||
{
|
||||
using value_type = T;
|
||||
using type = integral_constant; // using injected-class-name
|
||||
static constexpr T value = v;
|
||||
// constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
|
||||
// constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
|
||||
};
|
||||
|
||||
template <index_t v>
|
||||
|
||||
@@ -12,15 +12,19 @@ struct numeric_utils : numeric_traits<T>
|
||||
|
||||
using traits = numeric_traits<T>;
|
||||
using _numeric = numeric<T>;
|
||||
using raw_type = typename T::raw_type;
|
||||
using raw_type = typename traits::bitwise_type;
|
||||
|
||||
static constexpr int exp_mask = (1 << traits::exp) - 1;
|
||||
|
||||
static constexpr int get_exponent(raw_type x)
|
||||
static constexpr raw_type get_exponent(raw_type x)
|
||||
{
|
||||
// TODO: check if repeated calls are optimized.
|
||||
return (x >> traits::mant) & exp_mask;
|
||||
}
|
||||
static constexpr raw_type get_exponent(const T& x)
|
||||
{
|
||||
return get_exponent(bit_cast<raw_type>(x));
|
||||
}
|
||||
static constexpr bool is_positive(raw_type x)
|
||||
{
|
||||
return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
|
||||
@@ -33,7 +37,7 @@ struct numeric_utils : numeric_traits<T>
|
||||
static constexpr double get_mantissa(raw_type x)
|
||||
{
|
||||
double mantissa = is_subnormal(x) ? 0.0f : 1.0f;
|
||||
for(uint32_t i = 0; i < traits::mant; ++i)
|
||||
for(raw_type i = 0; i < traits::mant; ++i)
|
||||
{
|
||||
mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
|
||||
x >>= 1;
|
||||
@@ -43,22 +47,23 @@ struct numeric_utils : numeric_traits<T>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp = 127)
|
||||
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale = 1.f)
|
||||
{
|
||||
using utils = numeric_utils<T>;
|
||||
static constexpr int e8m0_bias = 127; // TODO: make it generic.
|
||||
float sign = utils::is_positive(data) ? 1.0 : -1.0;
|
||||
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
|
||||
float mant = utils::get_mantissa(data);
|
||||
using utils = numeric_utils<T>;
|
||||
float sign = utils::is_positive(data) ? 1.0 : -1.0;
|
||||
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
|
||||
float mant = utils::get_mantissa(data);
|
||||
|
||||
return std::ldexp(sign * mant, exp + scale_exp - e8m0_bias);
|
||||
return std::ldexp(sign * mant * scale, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value)
|
||||
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value, float scale = 1.f)
|
||||
{
|
||||
using bitwise_type = typename numeric_traits<T>::bitwise_type;
|
||||
|
||||
value /= scale;
|
||||
|
||||
if(std::abs(value) > float(numeric<T>::max()))
|
||||
{
|
||||
float max_value = numeric<T>::max();
|
||||
|
||||
@@ -23,14 +23,11 @@ using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float);
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f);
|
||||
|
||||
// TODO: Add stochastic method
|
||||
struct pk_float4_e2m1_t
|
||||
{
|
||||
static constexpr int exponent = 2;
|
||||
static constexpr int mantissa = 1;
|
||||
static constexpr int bias = 1;
|
||||
// TODO: Can we merge raw_type and type?
|
||||
using raw_type = uint8_t;
|
||||
using type = raw_type;
|
||||
@@ -41,18 +38,27 @@ struct pk_float4_e2m1_t
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init) : data{float_to_e2m1(init)}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f)
|
||||
: data{float_to_e2m1(init, scale)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t to_fp16(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const { return to_fp16(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type unpack(number<I>) const;
|
||||
@@ -191,131 +197,160 @@ CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
|
||||
} // namespace impl
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<bf16_t>(data);
|
||||
return impl::_from_f4<bf16_t>(data, scale);
|
||||
#else
|
||||
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
|
||||
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16x2_t() const
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<bf16x2_t>(data);
|
||||
return impl::_from_f4<bf16x2_t>(data, scale);
|
||||
#else
|
||||
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
|
||||
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
|
||||
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
|
||||
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
|
||||
// TODO: make float_to_e2m1 generic so that we can convert from directrly.
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return convert_to_type<pk_fp4_t>(x);
|
||||
return convert_to_type<pk_fp4_t>(x, scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x) { return fp32x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x) { return fp16x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x) { return bf16x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x) { return float_to_e2m1(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
|
||||
{
|
||||
return float_to_e2m1(x, scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return float_to_e2m1(type_convert<float>(x));
|
||||
return float_to_e2m1(type_convert<float>(x), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return float_to_e2m1(type_convert<float>(x));
|
||||
return float_to_e2m1(type_convert<float>(x), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
|
||||
float_to_e2m1(type_convert<float>(x[1])));
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
|
||||
float_to_e2m1(type_convert<float>(x[1]), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
|
||||
float_to_e2m1(type_convert<float>(x[1])));
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
|
||||
float_to_e2m1(type_convert<float>(x[1]), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(x[0]), float_to_e2m1(x[1]));
|
||||
return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_fp32x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_fp16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_bf16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_float(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_fp16(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_bf16(scale);
|
||||
}
|
||||
|
||||
#if TEST_convert_with_table == 0
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp32_t>(data);
|
||||
return impl::_from_f4<fp32_t>(data, scale);
|
||||
#else
|
||||
return convert_to_float<pk_fp4_t>(unpack(number<0>{}));
|
||||
return convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp32x2_t>(data);
|
||||
return impl::_from_f4<fp32x2_t>(data, scale);
|
||||
#else
|
||||
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{})),
|
||||
convert_to_float<pk_fp4_t>(unpack(number<1>{}))};
|
||||
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale),
|
||||
convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale)};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp16_t>(data);
|
||||
return impl::_from_f4<fp16_t>(data, scale);
|
||||
#else
|
||||
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
|
||||
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp16x2_t>(data);
|
||||
return impl::_from_f4<fp16x2_t>(data, scale);
|
||||
#else
|
||||
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
|
||||
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
|
||||
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
|
||||
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
|
||||
{
|
||||
return e2m1_to_fp32_table[data & 0xf];
|
||||
return e2m1_to_fp32_table[unpack(number<0>{})] * scale;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
|
||||
{
|
||||
return fp32x2_t{e2m1_to_fp32_table[data & 0xf], e2m1_to_fp32_table[data >> 4]};
|
||||
return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale, e2m1_to_fp32_table[unpack(number<1>{}] * scale};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
|
||||
{
|
||||
return e2m1_to_fp16_table[data & 0xf];
|
||||
return type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
|
||||
{
|
||||
return fp16x2_t{e2m1_to_fp16_table[data & 0xf], e2m1_to_fp16_table[data >> 4]};
|
||||
return fp16x2_t{
|
||||
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale),
|
||||
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)};
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -64,6 +64,7 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
|
||||
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -71,16 +72,36 @@ CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
|
||||
CK_TILE_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
|
||||
CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
|
||||
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale);
|
||||
|
||||
#define CK_TILE_SCALED_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ scaled_type_convert<dtype_, stype_>(stype_ x, \
|
||||
float scale) \
|
||||
{ \
|
||||
return sname_##_to_##dname_(x, scale); \
|
||||
} \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return sname_##_to_##dname_(x, 1.f); \
|
||||
}
|
||||
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(float, float, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(bf16_t, bf16, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp16_t, fp16, pk_fp4_t, pk_fp4)
|
||||
#undef CK_TILE_SCALED_TYPE_CONVERT
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -84,7 +84,7 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
|
||||
|
||||
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
|
||||
// ... unless we have other vector_traits specialization
|
||||
template <typename T, typename>
|
||||
template <typename T, typename = void>
|
||||
struct vector_traits
|
||||
{
|
||||
using scalar_type =
|
||||
@@ -94,7 +94,7 @@ struct vector_traits
|
||||
|
||||
// specialization for ext_vector_type()
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<T __attribute__((ext_vector_type(N)))>
|
||||
struct vector_traits<T __attribute__((ext_vector_type(N))), void>
|
||||
{
|
||||
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
|
||||
static constexpr index_t vector_size = N;
|
||||
|
||||
@@ -210,28 +210,6 @@ struct buffer_view<address_space_enum::generic,
|
||||
|
||||
// FIXME: remove
|
||||
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("buffer_view{");
|
||||
|
||||
// AddressSpace
|
||||
printf("AddressSpace: generic, ");
|
||||
|
||||
// p_data_
|
||||
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
|
||||
|
||||
// buffer_size_
|
||||
printf("buffer_size_: ");
|
||||
print(buffer_size_);
|
||||
printf(", ");
|
||||
|
||||
// invalid_element_value_
|
||||
printf("invalid_element_value_: ");
|
||||
print(invalid_element_value_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
// Address Space: Global
|
||||
@@ -757,28 +735,6 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// FIXME: remove
|
||||
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("buffer_view{");
|
||||
|
||||
// AddressSpace
|
||||
printf("AddressSpace: Global, ");
|
||||
|
||||
// p_data_
|
||||
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
|
||||
|
||||
// buffer_size_
|
||||
printf("buffer_size_: ");
|
||||
print(buffer_size_);
|
||||
printf(", ");
|
||||
|
||||
// invalid_element_value_
|
||||
printf("invalid_element_value_: ");
|
||||
print(invalid_element_value_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
// Address Space: LDS
|
||||
@@ -1138,28 +1094,6 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
// FIXME: remove
|
||||
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("buffer_view{");
|
||||
|
||||
// AddressSpace
|
||||
printf("AddressSpace: Lds, ");
|
||||
|
||||
// p_data_
|
||||
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
|
||||
|
||||
// buffer_size_
|
||||
printf("buffer_size_: ");
|
||||
print(buffer_size_);
|
||||
printf(", ");
|
||||
|
||||
// invalid_element_value_
|
||||
printf("invalid_element_value_: ");
|
||||
print(invalid_element_value_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
// Address Space: Vgpr
|
||||
@@ -1313,28 +1247,6 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
|
||||
// FIXME: remove
|
||||
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("buffer_view{");
|
||||
|
||||
// AddressSpace
|
||||
printf("AddressSpace: Vgpr, ");
|
||||
|
||||
// p_data_
|
||||
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
|
||||
|
||||
// buffer_size_
|
||||
printf("buffer_size_: ");
|
||||
print(buffer_size_);
|
||||
printf(", ");
|
||||
|
||||
// invalid_element_value_
|
||||
printf("invalid_element_value_: ");
|
||||
print(invalid_element_value_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <address_space_enum BufferAddressSpace,
|
||||
@@ -1360,4 +1272,25 @@ make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
|
||||
p, buffer_size, invalid_element_value};
|
||||
}
|
||||
|
||||
// Generalized print function for all buffer_view variants
|
||||
template <address_space_enum BufferAddressSpace,
|
||||
typename T,
|
||||
typename BufferSizeType,
|
||||
bool InvalidElementUseNumericalZeroValue,
|
||||
amd_buffer_coherence_enum Coherence>
|
||||
CK_TILE_HOST_DEVICE void print(const buffer_view<BufferAddressSpace,
|
||||
T,
|
||||
BufferSizeType,
|
||||
InvalidElementUseNumericalZeroValue,
|
||||
Coherence>& bv)
|
||||
{
|
||||
printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ",
|
||||
address_space_to_string(BufferAddressSpace),
|
||||
static_cast<void*>(const_cast<remove_cvref_t<T>*>(bv.p_data_)));
|
||||
print(bv.buffer_size_);
|
||||
printf(", invalid_element_value_: ");
|
||||
print(bv.invalid_element_value_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -53,10 +53,13 @@ struct is_null_tile_window<null_tile_window<T>> : public std::true_type
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_null_tile_window_v = impl::is_null_tile_window<remove_cvref_t<T>>::value;
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
|
||||
{
|
||||
return impl::is_null_tile_window<remove_cvref_t<T>>::value;
|
||||
return is_null_tile_window_v<remove_cvref_t<T>>;
|
||||
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
|
||||
@@ -305,42 +305,45 @@ struct tensor_adaptor
|
||||
get_container_subset(vector_strides, top_dims));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_adaptor{");
|
||||
|
||||
//
|
||||
printf("transforms: ");
|
||||
print(transforms_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("LowerDimensionHiddenIds: ");
|
||||
print(LowerDimensionHiddenIdss{});
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("UpperDimensionHiddenIds: ");
|
||||
print(UpperDimensionHiddenIdss{});
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("BottomDimensionHiddenIds: ");
|
||||
print(BottomDimensionHiddenIds{});
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("TopDimensionHiddenIds: ");
|
||||
print(TopDimensionHiddenIds{});
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
private:
|
||||
Transforms transforms_;
|
||||
ElementSize element_size_;
|
||||
};
|
||||
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename BottomDimensionHiddenIds,
|
||||
typename TopDimensionHiddenIds>
|
||||
CK_TILE_HOST_DEVICE static void print(const tensor_adaptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
BottomDimensionHiddenIds,
|
||||
TopDimensionHiddenIds>& adaptor)
|
||||
{
|
||||
printf("tensor_adaptor{\n");
|
||||
printf(" transforms: [");
|
||||
print(adaptor.get_transforms());
|
||||
printf("],\n");
|
||||
|
||||
printf(" LowerDimensionHiddenIds: [");
|
||||
print(LowerDimensionHiddenIdss{});
|
||||
printf("],\n");
|
||||
|
||||
printf(" UpperDimensionHiddenIds: [");
|
||||
print(UpperDimensionHiddenIdss{});
|
||||
printf("],\n");
|
||||
|
||||
printf(" BottomDimensionHiddenIds: [");
|
||||
print(BottomDimensionHiddenIds{});
|
||||
printf("],\n");
|
||||
|
||||
//
|
||||
printf(" TopDimensionHiddenIds: [");
|
||||
print(TopDimensionHiddenIds{});
|
||||
printf("]\n}\n");
|
||||
}
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
|
||||
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
|
||||
|
||||
@@ -140,25 +140,37 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
|
||||
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_descriptor{");
|
||||
|
||||
// tensor_adaptor
|
||||
Base::print();
|
||||
printf(", ");
|
||||
|
||||
// element_space_size_
|
||||
printf("element_space_size_: ");
|
||||
print(element_space_size_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// TODO make these private
|
||||
ElementSpaceSize element_space_size_;
|
||||
};
|
||||
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename TopDimensionHiddenIds,
|
||||
typename ElementSpaceSize,
|
||||
typename GuaranteedVectorLengths,
|
||||
typename GuaranteedVectorStrides>
|
||||
CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
TopDimensionHiddenIds,
|
||||
ElementSpaceSize,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>& descriptor)
|
||||
{
|
||||
printf("tensor_descriptor{\n");
|
||||
// first print the tensor adaptor part of the descriptor using the base class print
|
||||
print(static_cast<const typename decltype(descriptor)::Base&>(descriptor));
|
||||
printf("element_space_size_: %ld,\n",
|
||||
static_cast<long>(descriptor.get_element_space_size().value));
|
||||
printf("guaranteed_vector_lengths: ");
|
||||
print(GuaranteedVectorLengths{});
|
||||
printf(",\nguaranteed_vector_strides: ");
|
||||
print(GuaranteedVectorStrides{});
|
||||
printf("}\n}\n");
|
||||
}
|
||||
|
||||
template <typename Adaptor, typename ElementSpaceSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
|
||||
|
||||
@@ -228,24 +228,6 @@ struct tile_distribution
|
||||
{
|
||||
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution{");
|
||||
//
|
||||
printf("tile_distribution_encoding: ");
|
||||
print(DstrEncode{});
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_ys_to_xs_: ");
|
||||
print(ps_ys_to_xs_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_d_: ");
|
||||
print(ys_to_d_);
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
@@ -710,4 +692,27 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Free print function for tile_distribution
|
||||
template <typename PsYs2XsAdaptor_,
|
||||
typename Ys2DDescriptor_,
|
||||
typename StaticTileDistributionEncoding_,
|
||||
typename TileDistributionDetail_>
|
||||
CK_TILE_HOST_DEVICE void print(const tile_distribution<PsYs2XsAdaptor_,
|
||||
Ys2DDescriptor_,
|
||||
StaticTileDistributionEncoding_,
|
||||
TileDistributionDetail_>& distribution)
|
||||
{
|
||||
printf("tile_distribution{");
|
||||
printf("tile_distribution_encoding: ");
|
||||
print(StaticTileDistributionEncoding_{});
|
||||
printf(", ");
|
||||
printf("ps_ys_to_xs_: ");
|
||||
print(distribution.ps_ys_to_xs_);
|
||||
printf(", ");
|
||||
printf("ys_to_d_: ");
|
||||
print(distribution.ys_to_d_);
|
||||
printf("}\n");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -428,109 +428,7 @@ struct tile_distribution_encoding
|
||||
{
|
||||
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution_encoding::detail{");
|
||||
//
|
||||
printf("ndim_rh_major_: ");
|
||||
print(ndim_rh_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndim_span_major_: ");
|
||||
print(ndim_span_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndims_rhs_minor_: ");
|
||||
print(ndims_rhs_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndim_rh_major_: ");
|
||||
print(ndim_rh_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("max_ndim_rh_minor_: ");
|
||||
print(max_ndim_rh_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("rhs_lengthss_: ");
|
||||
print(rhs_lengthss_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_lengths_: ");
|
||||
print(ys_lengths_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("rhs_major_minor_to_ys_: ");
|
||||
print(rhs_major_minor_to_ys_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndims_span_minor_: ");
|
||||
print(ndims_span_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("max_ndim_span_minor_: ");
|
||||
print(max_ndim_span_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_span_major_: ");
|
||||
print(ys_to_span_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_span_minor_: ");
|
||||
print(ys_to_span_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("distributed_spans_lengthss_: ");
|
||||
print(distributed_spans_lengthss_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndims_distributed_spans_minor_: ");
|
||||
print(ndims_distributed_spans_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_over_rs_derivative_: ");
|
||||
print(ps_over_rs_derivative_);
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution_encoding{");
|
||||
//
|
||||
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
|
||||
//
|
||||
printf("rs_lengths_: ");
|
||||
print(rs_lengths_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("hs_lengthss_: ");
|
||||
print(hs_lengthss_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_to_rhss_major_: ");
|
||||
print(ps_to_rhss_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_to_rhss_minor_: ");
|
||||
print(ps_to_rhss_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_rhs_major_: ");
|
||||
print(ys_to_rhs_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_rhs_minor_: ");
|
||||
print(ys_to_rhs_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("detail: ");
|
||||
print(detail{});
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename encoding, typename shuffle>
|
||||
@@ -896,4 +794,106 @@ make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Free print function for tile_distribution_encoding::detail
|
||||
template <typename RsLengths_,
|
||||
typename HsLengthss_,
|
||||
typename Ps2RHssMajor_,
|
||||
typename Ps2RHssMinor_,
|
||||
typename Ys2RHsMajor_,
|
||||
typename Ys2RHsMinor_>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
print(const typename tile_distribution_encoding<RsLengths_,
|
||||
HsLengthss_,
|
||||
Ps2RHssMajor_,
|
||||
Ps2RHssMinor_,
|
||||
Ys2RHsMajor_,
|
||||
Ys2RHsMinor_>::detail& detail_obj)
|
||||
{
|
||||
printf("tile_distribution_encoding::detail{");
|
||||
printf("ndim_rh_major_: ");
|
||||
print(detail_obj.ndim_rh_major_);
|
||||
printf(", ");
|
||||
printf("ndim_span_major_: ");
|
||||
print(detail_obj.ndim_span_major_);
|
||||
printf(", ");
|
||||
printf("ndims_rhs_minor_: ");
|
||||
print(detail_obj.ndims_rhs_minor_);
|
||||
printf(", ");
|
||||
printf("ndim_rh_major_: ");
|
||||
print(detail_obj.ndim_rh_major_);
|
||||
printf(", ");
|
||||
printf("max_ndim_rh_minor_: ");
|
||||
print(detail_obj.max_ndim_rh_minor_);
|
||||
printf(", ");
|
||||
printf("rhs_lengthss_: ");
|
||||
print(detail_obj.rhs_lengthss_);
|
||||
printf(", ");
|
||||
printf("ys_lengths_: ");
|
||||
print(detail_obj.ys_lengths_);
|
||||
printf(", ");
|
||||
printf("rhs_major_minor_to_ys_: ");
|
||||
print(detail_obj.rhs_major_minor_to_ys_);
|
||||
printf(", ");
|
||||
printf("ndims_span_minor_: ");
|
||||
print(detail_obj.ndims_span_minor_);
|
||||
printf(", ");
|
||||
printf("max_ndim_span_minor_: ");
|
||||
print(detail_obj.max_ndim_span_minor_);
|
||||
printf(", ");
|
||||
printf("ys_to_span_major_: ");
|
||||
print(detail_obj.ys_to_span_major_);
|
||||
printf(", ");
|
||||
printf("ys_to_span_minor_: ");
|
||||
print(detail_obj.ys_to_span_minor_);
|
||||
printf(", ");
|
||||
printf("distributed_spans_lengthss_: ");
|
||||
print(detail_obj.distributed_spans_lengthss_);
|
||||
printf(", ");
|
||||
printf("ndims_distributed_spans_minor_: ");
|
||||
print(detail_obj.ndims_distributed_spans_minor_);
|
||||
printf(", ");
|
||||
printf("ps_over_rs_derivative_: ");
|
||||
print(detail_obj.ps_over_rs_derivative_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// Free print function for tile_distribution_encoding
|
||||
template <typename RsLengths_,
|
||||
typename HsLengthss_,
|
||||
typename Ps2RHssMajor_,
|
||||
typename Ps2RHssMinor_,
|
||||
typename Ys2RHsMajor_,
|
||||
typename Ys2RHsMinor_>
|
||||
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding<RsLengths_,
|
||||
HsLengthss_,
|
||||
Ps2RHssMajor_,
|
||||
Ps2RHssMinor_,
|
||||
Ys2RHsMajor_,
|
||||
Ys2RHsMinor_>& encoding)
|
||||
{
|
||||
printf("tile_distribution_encoding{");
|
||||
|
||||
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY);
|
||||
printf("rs_lengths_: ");
|
||||
print(encoding.rs_lengths_);
|
||||
printf(", ");
|
||||
printf("hs_lengthss_: ");
|
||||
print(encoding.hs_lengthss_);
|
||||
printf(", ");
|
||||
printf("ps_to_rhss_major_: ");
|
||||
print(encoding.ps_to_rhss_major_);
|
||||
printf(", ");
|
||||
printf("ps_to_rhss_minor_: ");
|
||||
print(encoding.ps_to_rhss_minor_);
|
||||
printf(", ");
|
||||
printf("ys_to_rhs_major_: ");
|
||||
print(encoding.ys_to_rhs_major_);
|
||||
printf(", ");
|
||||
printf("ys_to_rhs_minor_: ");
|
||||
print(encoding.ys_to_rhs_minor_);
|
||||
printf(", ");
|
||||
printf("}");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
76
include/ck_tile/core/utility/print.hpp
Normal file
76
include/ck_tile/core/utility/print.hpp
Normal file
@@ -0,0 +1,76 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// Declare a ck_tile::print() interface that gets specialized in each header file for types that
|
||||
/// can be printed.
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void print(const T&)
|
||||
{
|
||||
static_assert(sizeof(T) == 0,
|
||||
"No print implementation available for this type. Please specialize "
|
||||
"ck_tile::print for your type.");
|
||||
}
|
||||
|
||||
/// Specialization for int
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const int& value)
|
||||
{
|
||||
printf("%d", value);
|
||||
}
|
||||
|
||||
/// Specialization for float
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const float& value)
|
||||
{
|
||||
printf("%f", value);
|
||||
}
|
||||
|
||||
/// Specialization for double
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const double& value)
|
||||
{
|
||||
printf("%f", value);
|
||||
}
|
||||
|
||||
/// Specialization for long
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const long& value)
|
||||
{
|
||||
printf("%ld", value);
|
||||
}
|
||||
|
||||
/// Specialization for unsigned int
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const unsigned int& value)
|
||||
{
|
||||
printf("%u", value);
|
||||
}
|
||||
|
||||
/// Specialization for char
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const char& value)
|
||||
{
|
||||
printf("%c", value);
|
||||
}
|
||||
|
||||
/// Specialization for array
|
||||
template <typename T, size_t N>
|
||||
CK_TILE_HOST_DEVICE void print(const T (&value)[N])
|
||||
{
|
||||
printf("[");
|
||||
for(size_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(i > 0)
|
||||
printf(", ");
|
||||
print(value[i]); // Recursively call print for each element
|
||||
}
|
||||
printf("]");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -409,7 +409,13 @@ struct HostTensor
|
||||
}
|
||||
|
||||
// void SetZero() { ck_tile::ranges::fill<T>(mData, 0); }
|
||||
void SetZero() { std::fill(mData.begin(), mData.end(), 0); }
|
||||
void SetZero()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, e8m0_t>)
|
||||
std::fill(mData.begin(), mData.end(), e8m0_t{1.f});
|
||||
else
|
||||
std::fill(mData.begin(), mData.end(), 0);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
|
||||
|
||||
@@ -24,8 +24,8 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
|
||||
|
||||
@@ -52,8 +52,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
using BiasGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasGradDataType>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
@@ -85,8 +83,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
#define _TS_ std::to_string
|
||||
auto pn = [&] () {
|
||||
std::string n;
|
||||
if (kPadSeqLenQ) n += "s";
|
||||
if (kPadSeqLenK) n += "sk";
|
||||
if (kPadHeadDimQ) n += "d";
|
||||
if (kPadHeadDimV) n += "dv";
|
||||
return n.empty() ? n : std::string("p") + n; }();
|
||||
@@ -100,7 +96,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
"r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
|
||||
"w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
|
||||
"w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
|
||||
("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) +
|
||||
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" );
|
||||
@@ -1221,7 +1217,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto q_dram = pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
@@ -1232,7 +1228,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto k_dram = pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
|
||||
const auto v_dram = [&]() {
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -1244,22 +1240,15 @@ struct FmhaBwdDQDKDVKernel
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
sequence<kPadSeqLenK, kPadHeadDimV>{});
|
||||
sequence<false, kPadHeadDimV>{});
|
||||
}();
|
||||
|
||||
const auto lse_dram = [&]() {
|
||||
const auto lse_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
lse_ptr, make_tuple(kargs.seqlen_q), number<1>{});
|
||||
return pad_tensor_view(
|
||||
lse_dram_naive, make_tuple(number<FmhaPipeline::kM0>{}), sequence<kPadSeqLenQ>{});
|
||||
}();
|
||||
// lse and d should be fine to read unpaded data as they are not on the reduction dimension
|
||||
const auto lse_dram = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
lse_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});
|
||||
|
||||
const auto d_dram = [&]() {
|
||||
const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
|
||||
return pad_tensor_view(
|
||||
d_dram_naive, make_tuple(number<FmhaPipeline::kM0>{}), sequence<kPadSeqLenQ>{});
|
||||
}();
|
||||
const auto d_dram = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
d_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});
|
||||
|
||||
const auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
do_ptr,
|
||||
@@ -1270,7 +1259,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto do_dram = pad_tensor_view(
|
||||
do_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimV>{});
|
||||
sequence<false, kPadHeadDimV>{});
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
@@ -1313,7 +1302,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
return pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
@@ -1341,7 +1330,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
return pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
@@ -1376,9 +1365,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
number<FmhaPipeline::kAlignmentBias>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(bias_dram_naive,
|
||||
bias_dram_window_lengths,
|
||||
sequence<kPadSeqLenQ, kPadSeqLenK>{});
|
||||
return pad_tensor_view(
|
||||
bias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0});
|
||||
@@ -1406,9 +1394,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
number<FmhaPipeline::kAlignmentBias>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(dbias_dram_naive,
|
||||
bias_dram_window_lengths,
|
||||
sequence<kPadSeqLenQ, kPadSeqLenK>{});
|
||||
return pad_tensor_view(
|
||||
dbias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0});
|
||||
@@ -1495,9 +1482,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(randval_dram_naive,
|
||||
randval_dram_window_lengths,
|
||||
sequence<kPadSeqLenQ, kPadSeqLenK>{});
|
||||
return pad_tensor_view(
|
||||
randval_dram_naive, randval_dram_window_lengths, sequence<false, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0});
|
||||
@@ -1550,7 +1536,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
return pad_tensor_view(
|
||||
dk_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
auto dv_dram = [&]() {
|
||||
@@ -1564,7 +1550,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
return pad_tensor_view(
|
||||
dv_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
sequence<kPadSeqLenK, kPadHeadDimV>{});
|
||||
sequence<false, kPadHeadDimV>{});
|
||||
}();
|
||||
|
||||
auto dk_dram_window = make_tile_window(
|
||||
|
||||
@@ -49,8 +49,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
@@ -72,8 +70,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "kr_ktr_vr";
|
||||
|
||||
@@ -554,7 +551,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
@@ -49,8 +49,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
@@ -72,8 +70,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "kr_ktr_vr_iglp";
|
||||
|
||||
@@ -590,7 +587,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
@@ -849,7 +845,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
class BlockFmhaBwdDQDKDVPipelineSelector
|
||||
{
|
||||
static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV;
|
||||
|
||||
public:
|
||||
using type = std::conditional_t<has_dpad,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVR<Problem>,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<Problem>>;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector<Problem>::type
|
||||
{
|
||||
public:
|
||||
static constexpr const char* name = "auto";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,15 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This class is used for codegen pattern matching
|
||||
enum class BlockFmhaBwdPipelineEnum
|
||||
{
|
||||
KRKTRVR_IGLP = 0,
|
||||
KRKTRVR,
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -55,13 +55,13 @@ struct BlockFmhaBwdPipelineProblem
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static_assert(!Traits::kPadSeqLenQ, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
|
||||
static_assert(!Traits::kPadSeqLenK, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
|
||||
};
|
||||
|
||||
template <typename ODataType_,
|
||||
|
||||
@@ -21,4 +21,5 @@ add_subdirectory(add_rmsnorm2d_rdquant)
|
||||
# add_subdirectory(layernorm2d)
|
||||
# add_subdirectory(rmsnorm2d)
|
||||
add_subdirectory(gemm_block_scale)
|
||||
add_subdirectory(reduce)
|
||||
add_subdirectory(utility)
|
||||
add_subdirectory(reduce)
|
||||
|
||||
@@ -3,6 +3,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
|
||||
add_gtest_executable(test_ck_tile_mx_scale test_mx_scale.cpp)
|
||||
endif()
|
||||
|
||||
if(CK_USE_OCP_FP8 OR CK_USE_FNUZ_FP8)
|
||||
|
||||
162
test/ck_tile/data_type/test_mx_scale.cpp
Normal file
162
test/ck_tile/data_type/test_mx_scale.cpp
Normal file
@@ -0,0 +1,162 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
using ck_tile::bf16_t;
|
||||
using ck_tile::bf16x2_t;
|
||||
using ck_tile::fp16_t;
|
||||
using ck_tile::fp16x2_t;
|
||||
using ck_tile::fp32_t;
|
||||
using ck_tile::fp32x2_t;
|
||||
using ck_tile::number;
|
||||
using ck_tile::pk_fp4_t;
|
||||
|
||||
template <typename SRC, typename DST, bool is_device>
|
||||
CK_TILE_HOST void test_convert();
|
||||
|
||||
using ck_tile::e8m0_raw_t;
|
||||
using ck_tile::e8m0_t;
|
||||
|
||||
TEST(OCP_Scale, NumericLimits)
|
||||
{
|
||||
EXPECT_EQ(ck_tile::numeric<e8m0_t>::has_inf(), false);
|
||||
EXPECT_EQ(ck_tile::numeric<e8m0_t>::zero(), ck_tile::numeric<e8m0_t>::signaling_NaN());
|
||||
EXPECT_EQ(ck_tile::numeric<e8m0_t>::min(), e8m0_t{e8m0_raw_t{0b00000000}});
|
||||
EXPECT_EQ(ck_tile::numeric<e8m0_t>::max(), e8m0_t{e8m0_raw_t{0b11111110}});
|
||||
}
|
||||
TEST(OCP_Scale, NumericBasic)
|
||||
{
|
||||
auto scale_1 = e8m0_t{1.0f};
|
||||
auto scale_2 = e8m0_t{e8m0_raw_t{ck_tile::numeric_traits<e8m0_t>::bias}}; // 2^0
|
||||
EXPECT_EQ(scale_1, scale_2);
|
||||
|
||||
auto scale_3 = e8m0_t{8.0f};
|
||||
auto scale_4 = e8m0_t{e8m0_raw_t{3 + ck_tile::numeric_traits<e8m0_t>::bias}}; // 2^3
|
||||
EXPECT_EQ(scale_3, scale_4);
|
||||
}
|
||||
|
||||
TEST(OCP_Scale, ScaledConvertDevice)
|
||||
{
|
||||
constexpr bool is_device = true;
|
||||
test_convert<fp32_t, fp32_t, is_device>(); // fp32 -> fp4 -> fp32
|
||||
test_convert<fp16_t, fp16_t, is_device>();
|
||||
test_convert<bf16_t, bf16_t, is_device>();
|
||||
test_convert<fp32_t, fp16_t, is_device>();
|
||||
test_convert<fp32_t, bf16_t, is_device>();
|
||||
test_convert<fp16_t, fp32_t, is_device>();
|
||||
test_convert<bf16_t, fp32_t, is_device>();
|
||||
}
|
||||
TEST(OCP_Scale, ScaledConvertHost)
|
||||
{
|
||||
constexpr bool is_device = false;
|
||||
test_convert<fp32_t, fp32_t, is_device>(); // fp32 -> fp4 -> fp32
|
||||
test_convert<fp16_t, fp16_t, is_device>();
|
||||
test_convert<bf16_t, bf16_t, is_device>();
|
||||
test_convert<fp32_t, fp16_t, is_device>();
|
||||
test_convert<fp32_t, bf16_t, is_device>();
|
||||
test_convert<fp16_t, fp32_t, is_device>();
|
||||
test_convert<bf16_t, fp32_t, is_device>();
|
||||
}
|
||||
TEST(OCP_Scale, tensorInit)
|
||||
{
|
||||
using scale_t = e8m0_t;
|
||||
ck_tile::HostTensor<scale_t> scales({10, 10});
|
||||
ck_tile::FillUniformDistribution<scale_t>{1.f, 1.f}(scales);
|
||||
scales.SetZero();
|
||||
}
|
||||
|
||||
#define toPF4(x, y) ck_tile::scaled_type_convert<pk_fp4_t>(x, y)
|
||||
#define toDST(x, y) ck_tile::scaled_type_convert<DST>(x, y)
|
||||
#define toDSTx2(x, y) ck_tile::scaled_type_convert<DSTx2_t>(x, y)
|
||||
|
||||
#define toF32(x) ck_tile::type_convert<float>(x)
|
||||
#define toPF4_(x) ck_tile::type_convert<pk_fp4_t>(x)
|
||||
#define toSRC(x) ck_tile::type_convert<SRC>(x)
|
||||
#define toDST_(x) ck_tile::type_convert<DST>(x)
|
||||
|
||||
template <typename Kernel, typename... Args>
|
||||
__global__ void MyKernel(Args... args)
|
||||
{
|
||||
Kernel{}(args...);
|
||||
}
|
||||
template <typename SRC, typename DST, int N>
|
||||
struct SrcPkfp4Dst
|
||||
{
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()(const SRC* src, DST* dst, e8m0_t scale1, e8m0_t scale2) const
|
||||
{
|
||||
|
||||
using SRCx2_t = ck_tile::ext_vector_t<SRC, 2>;
|
||||
using DSTx2_t = ck_tile::ext_vector_t<DST, 2>;
|
||||
|
||||
ck_tile::static_for<0, N, 2>{}([&](auto i) {
|
||||
const auto input2 = SRCx2_t{src[i], src[i + 1]};
|
||||
|
||||
if(i % 4 == 0)
|
||||
{
|
||||
// ex: fp32_t -> fp4 -> bf16_t
|
||||
dst[i] = toDST(toPF4(src[i], scale1), scale2);
|
||||
// ex: fp32x2_t -> pk_fp4 -> unpack<0> -> bf16_t
|
||||
dst[i + 1] = toDST(toPF4_(toPF4(input2, scale1).unpack(number<1>{})), scale2);
|
||||
}
|
||||
else
|
||||
{
|
||||
// ex: fp32x2_t -> pk_fp4_t -> bf16x2_t
|
||||
reinterpret_cast<DSTx2_t*>(dst)[i >> 1] = toDSTx2(toPF4(input2, scale1), scale2);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SRC, typename DST, bool is_device>
|
||||
CK_TILE_HOST void test_convert()
|
||||
{
|
||||
const auto test_data = std::array{4.f, 6.f, 8.f, 10.f};
|
||||
const auto ref_data = std::array{8.f, 16.f, 16.f, 16.f};
|
||||
const auto scale1 = e8m0_t{8.0f};
|
||||
const auto scale2 = e8m0_t{16.0f};
|
||||
|
||||
static_assert(test_data.size() == ref_data.size());
|
||||
static_assert(test_data.size() % 2 == 0);
|
||||
|
||||
constexpr int N = test_data.size();
|
||||
std::array<SRC, N> in;
|
||||
std::array<DST, N> ref, out;
|
||||
|
||||
// prepare input and ground truth in host
|
||||
for(int i = 0; i < N; ++i)
|
||||
{
|
||||
in[i] = toSRC(test_data[i]);
|
||||
ref[i] = toDST_(ref_data[i]);
|
||||
EXPECT_EQ(test_data[i], toF32(in[i]));
|
||||
EXPECT_EQ(ref_data[i], toF32(ref[i]));
|
||||
}
|
||||
|
||||
using job = SrcPkfp4Dst<SRC, DST, N>;
|
||||
|
||||
if constexpr(is_device)
|
||||
{
|
||||
auto in_d = std::make_unique<ck_tile::DeviceMem>(in.size() * sizeof(SRC));
|
||||
auto out_d = std::make_unique<ck_tile::DeviceMem>(out.size() * sizeof(DST));
|
||||
in_d->ToDevice(in.data());
|
||||
|
||||
MyKernel<job><<<1, 1>>>(reinterpret_cast<const SRC*>(in_d->GetDeviceBuffer()),
|
||||
reinterpret_cast<DST*>(out_d->GetDeviceBuffer()),
|
||||
scale1,
|
||||
scale2);
|
||||
|
||||
out_d->FromDevice(out.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
job{}(in.data(), out.data(), scale1, scale2);
|
||||
}
|
||||
|
||||
for(int i = 0; i < N; ++i)
|
||||
EXPECT_EQ(ref[i], out[i]) << "i:" << i;
|
||||
}
|
||||
4
test/ck_tile/utility/CMakeLists.txt
Normal file
4
test/ck_tile/utility/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
message("-- Adding: test/ck_tile/utility/")
|
||||
|
||||
# Add print tests
|
||||
add_subdirectory(print)
|
||||
8
test/ck_tile/utility/print/CMakeLists.txt
Normal file
8
test/ck_tile/utility/print/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
# Print utility tests
|
||||
add_gtest_executable(test_print_sequence test_print_sequence.cpp)
|
||||
add_gtest_executable(test_print_array test_print_array.cpp)
|
||||
add_gtest_executable(test_print_tuple test_print_tuple.cpp)
|
||||
add_gtest_executable(test_print_coordinate_transform test_print_coordinate_transform.cpp)
|
||||
add_gtest_executable(test_print_static_encoding_pattern test_print_static_encoding_pattern.cpp)
|
||||
add_gtest_executable(test_print_buffer_view test_print_buffer_view.cpp)
|
||||
add_gtest_executable(test_print_basic_types test_print_basic_types.cpp)
|
||||
70
test/ck_tile/utility/print/README.md
Normal file
70
test/ck_tile/utility/print/README.md
Normal file
@@ -0,0 +1,70 @@
|
||||
# Print Function Tests
|
||||
|
||||
This directory contains unit tests for testing the print functionality of various data structures and coordinate transformations in the composable_kernel library.
|
||||
|
||||
## Tests Included
|
||||
|
||||
### test_print_sequence.cpp
|
||||
Tests the print functionality for `sequence<...>` containers:
|
||||
- Simple sequences with multiple elements
|
||||
- Single element sequences
|
||||
- Empty sequences
|
||||
- Longer sequences
|
||||
|
||||
### test_print_array.cpp
|
||||
Tests the print functionality for `array<T, N>` containers:
|
||||
- Arrays with integer values
|
||||
- Single element arrays
|
||||
- Empty arrays (size 0)
|
||||
- Arrays with floating point values
|
||||
|
||||
### test_print_tuple.cpp
|
||||
Tests the print functionality for `tuple<...>` containers:
|
||||
- Simple tuples with numbers
|
||||
- Single element tuples
|
||||
- Empty tuples
|
||||
- Mixed type tuples
|
||||
|
||||
### test_print_coordinate_transform.cpp
|
||||
Tests the print functionality for coordinate transformation structures:
|
||||
- `pass_through` transform
|
||||
- `embed` transform
|
||||
- `merge` transform
|
||||
- `unmerge` transform
|
||||
- `freeze` transform
|
||||
|
||||
## Testing Approach
|
||||
|
||||
All tests use Google Test's `CaptureStdout()` functionality to capture the output from print functions and verify the formatting:
|
||||
|
||||
```cpp
|
||||
testing::internal::CaptureStdout();
|
||||
print(object);
|
||||
std::string output = testing::internal::GetCapturedStdout();
|
||||
EXPECT_EQ(output, "expected_format");
|
||||
```
|
||||
|
||||
This approach enables testing of print function output without affecting the console during test execution.
|
||||
|
||||
## Building and Running
|
||||
|
||||
The tests are integrated into the CMake build system. To build and run the print tests:
|
||||
|
||||
```bash
|
||||
# Build the specific test
|
||||
make test_print_sequence
|
||||
|
||||
# Run the test
|
||||
./test_print_sequence
|
||||
|
||||
# Or run all print tests using CTest
|
||||
ctest -R "test_print"
|
||||
```
|
||||
|
||||
## Adding New Tests
|
||||
|
||||
To add tests for new data structures:
|
||||
|
||||
1. Create a new test file: `test_print_<structure_name>.cpp`
|
||||
2. Follow the existing pattern using `CaptureStdout()`
|
||||
3. Add the test executable to `CMakeLists.txt`
|
||||
59
test/ck_tile/utility/print/test_print_array.cpp
Normal file
59
test/ck_tile/utility/print/test_print_array.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintArrayTest : public PrintTest
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(PrintArrayTest, PrintIntArray)
|
||||
{
|
||||
// Test printing array<int, 3>
|
||||
array<int, 3> arr{10, 20, 30};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
// The expected format should match the array print function implementation
|
||||
EXPECT_EQ(output, "array{size: 3, data: [10, 20, 30]}");
|
||||
}
|
||||
|
||||
TEST_F(PrintArrayTest, PrintSingleElementArray)
|
||||
{
|
||||
// Test printing array<int, 1>
|
||||
array<int, 1> arr{42};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_EQ(output, "array{size: 1, data: [42]}");
|
||||
}
|
||||
|
||||
TEST_F(PrintArrayTest, PrintEmptyArray)
|
||||
{
|
||||
// Test printing array<int, 0> (empty array)
|
||||
array<int, 0> arr{};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_EQ(output, "array{size: 0, data: []}");
|
||||
}
|
||||
|
||||
TEST_F(PrintArrayTest, PrintFloatArray)
|
||||
{
|
||||
// Test printing array with float values
|
||||
array<float, 2> arr{3.14f, 2.71f};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
// Note: float printing format may vary, so we'll test for basic structure
|
||||
EXPECT_TRUE(output.find("array{size: 2, data: [") == 0);
|
||||
EXPECT_TRUE(output.find("3.14") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("2.71") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("]}") == output.length() - 2);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
76
test/ck_tile/utility/print/test_print_basic_types.cpp
Normal file
76
test/ck_tile/utility/print/test_print_basic_types.cpp
Normal file
@@ -0,0 +1,76 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintBasicTypesTest : public PrintTest
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(PrintBasicTypesTest, PrintIntArray)
|
||||
{
|
||||
int arr[4] = {1, 2, 3, 4};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_EQ(output, "[1, 2, 3, 4]");
|
||||
}
|
||||
|
||||
TEST_F(PrintBasicTypesTest, PrintFloatArray)
|
||||
{
|
||||
float arr[3] = {1.5f, 2.5f, 3.5f};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
// Note: floating point formatting may vary, so we check for key elements
|
||||
EXPECT_TRUE(output.find("[") == 0);
|
||||
EXPECT_TRUE(output.find("1.5") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("2.5") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("3.5") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == ']');
|
||||
EXPECT_TRUE(output.find(", ") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(PrintBasicTypesTest, PrintDoubleArray)
|
||||
{
|
||||
double arr[2] = {10.123, 20.456};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_TRUE(output.find("[") == 0);
|
||||
EXPECT_TRUE(output.find("10.123") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("20.456") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == ']');
|
||||
}
|
||||
|
||||
TEST_F(PrintBasicTypesTest, PrintUnsignedIntArray)
|
||||
{
|
||||
unsigned int arr[3] = {100u, 200u, 300u};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_EQ(output, "[100, 200, 300]");
|
||||
}
|
||||
|
||||
TEST_F(PrintBasicTypesTest, PrintCharArray)
|
||||
{
|
||||
char arr[5] = {'a', 'b', 'c', 'd', 'e'};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_EQ(output, "[a, b, c, d, e]");
|
||||
}
|
||||
|
||||
TEST_F(PrintBasicTypesTest, PrintSingleElementArray)
|
||||
{
|
||||
int arr[1] = {42};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_EQ(output, "[42]");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
78
test/ck_tile/utility/print/test_print_buffer_view.cpp
Normal file
78
test/ck_tile/utility/print/test_print_buffer_view.cpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/tensor/buffer_view.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintBufferViewTest : public PrintTest
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(PrintBufferViewTest, PrintGenericBufferView)
|
||||
{
|
||||
// Test printing generic address space buffer_view
|
||||
float data[4] = {100.f, 200.f, 300.f, 400.f};
|
||||
auto bv = make_buffer_view<address_space_enum::generic>(&data, 4);
|
||||
|
||||
std::string output = CapturePrintOutput(bv);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("buffer_view{AddressSpace: generic") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("p_data_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("}") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(PrintBufferViewTest, PrintGlobalBufferView)
|
||||
{
|
||||
// Test printing global address space buffer_view
|
||||
float data[4] = {100.f, 200.f, 300.f, 400.f};
|
||||
auto bv = make_buffer_view<address_space_enum::global>(&data, 4);
|
||||
|
||||
std::string output = CapturePrintOutput(bv);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("buffer_view{AddressSpace: global") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("p_data_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("}") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(PrintBufferViewTest, PrintLdsBufferView)
|
||||
{
|
||||
// Test printing LDS address space buffer_view
|
||||
float data[4] = {100.f, 200.f, 300.f, 400.f};
|
||||
auto bv = make_buffer_view<address_space_enum::lds>(data, 4);
|
||||
|
||||
std::string output = CapturePrintOutput(bv);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("buffer_view{AddressSpace: lds") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("p_data_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("}") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(PrintBufferViewTest, PrintVgprBufferView)
|
||||
{
|
||||
// Test printing VGPR address space buffer_view
|
||||
float data[4] = {1.5f, 2.5f, 3.5f, 4.5f};
|
||||
auto bv = make_buffer_view<address_space_enum::vgpr>(data, 4);
|
||||
|
||||
std::string output = CapturePrintOutput(bv);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("buffer_view{AddressSpace: vgpr") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("p_data_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("}") != std::string::npos);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
25
test/ck_tile/utility/print/test_print_common.hpp
Normal file
25
test/ck_tile/utility/print/test_print_common.hpp
Normal file
@@ -0,0 +1,25 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gtest/gtest-spi.h>
|
||||
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
class PrintTest : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
// Helper function to capture and return the output of a print function
|
||||
template <typename T>
|
||||
std::string CapturePrintOutput(const T& type)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
testing::internal::CaptureStdout();
|
||||
print(type);
|
||||
return testing::internal::GetCapturedStdout();
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,83 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintCoordinateTransformTest : public PrintTest
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(PrintCoordinateTransformTest, PrintPassThrough)
|
||||
{
|
||||
// Test printing pass_through transform
|
||||
auto pt = make_pass_through_transform(number<32>{});
|
||||
|
||||
std::string output = CapturePrintOutput(pt);
|
||||
|
||||
// Verify it contains the pass_through identifier and some structure
|
||||
EXPECT_TRUE(output.find("pass_through{") == 0);
|
||||
EXPECT_TRUE(output.find("up_lengths_") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '}');
|
||||
}
|
||||
|
||||
TEST_F(PrintCoordinateTransformTest, PrintEmbed)
|
||||
{
|
||||
// Test printing embed transform
|
||||
auto embed_transform = make_embed_transform(make_tuple(number<4>{}, number<8>{}),
|
||||
make_tuple(number<1>{}, number<4>{}));
|
||||
|
||||
std::string output = CapturePrintOutput(embed_transform);
|
||||
|
||||
// Verify it contains the embed identifier and key fields
|
||||
EXPECT_TRUE(output.find("embed{") == 0);
|
||||
EXPECT_TRUE(output.find("up_lengths_") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("coefficients_") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '}');
|
||||
}
|
||||
|
||||
TEST_F(PrintCoordinateTransformTest, PrintMerge)
|
||||
{
|
||||
// Test printing merge transform
|
||||
auto merge_transform = make_merge_transform(make_tuple(number<4>{}, number<8>{}));
|
||||
|
||||
std::string output = CapturePrintOutput(merge_transform);
|
||||
|
||||
// Verify it contains merge identifier and key fields
|
||||
EXPECT_TRUE(output.find("merge") ==
|
||||
0); // Could be merge_v2_magic_division or merge_v3_division_mod
|
||||
EXPECT_TRUE(output.find("low_lengths_") != std::string::npos ||
|
||||
output.find("up_lengths_") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '}');
|
||||
}
|
||||
|
||||
TEST_F(PrintCoordinateTransformTest, PrintUnmerge)
|
||||
{
|
||||
// Test printing unmerge transform
|
||||
auto unmerge_transform = make_unmerge_transform(make_tuple(number<4>{}, number<8>{}));
|
||||
|
||||
std::string output = CapturePrintOutput(unmerge_transform);
|
||||
|
||||
// Verify it contains the unmerge identifier and key fields
|
||||
EXPECT_TRUE(output.find("unmerge{") == 0);
|
||||
EXPECT_TRUE(output.find("up_lengths_") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '}');
|
||||
}
|
||||
|
||||
TEST_F(PrintCoordinateTransformTest, PrintFreeze)
|
||||
{
|
||||
// Test printing freeze transform
|
||||
auto freeze_transform = make_freeze_transform(number<5>{});
|
||||
|
||||
std::string output = CapturePrintOutput(freeze_transform);
|
||||
|
||||
// Verify it contains the freeze identifier and key fields
|
||||
EXPECT_TRUE(output.find("freeze{") == 0);
|
||||
EXPECT_TRUE(output.find("low_idx_") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '}');
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
45
test/ck_tile/utility/print/test_print_sequence.cpp
Normal file
45
test/ck_tile/utility/print/test_print_sequence.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintSequenceTest : public PrintTest
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(PrintSequenceTest, PrintSimpleSequence)
|
||||
{
|
||||
// Test printing sequence<1, 5, 8>
|
||||
constexpr auto seq = sequence<1, 5, 8>{};
|
||||
|
||||
std::string output = CapturePrintOutput(seq);
|
||||
|
||||
// Verify the output format
|
||||
EXPECT_EQ(output, "sequence<1, 5, 8>");
|
||||
}
|
||||
|
||||
TEST_F(PrintSequenceTest, PrintSingleElementSequence)
|
||||
{
|
||||
// Test printing sequence<42>
|
||||
constexpr auto seq = sequence<42>{};
|
||||
|
||||
std::string output = CapturePrintOutput(seq);
|
||||
|
||||
EXPECT_EQ(output, "sequence<42>");
|
||||
}
|
||||
|
||||
TEST_F(PrintSequenceTest, PrintEmptySequence)
|
||||
{
|
||||
// Test printing sequence<> (empty sequence)
|
||||
constexpr auto seq = sequence<>{};
|
||||
|
||||
std::string output = CapturePrintOutput(seq);
|
||||
|
||||
EXPECT_EQ(output, "sequence<>");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,89 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintStaticEncodingPatternTest : public PrintTest
|
||||
{
|
||||
protected:
|
||||
void TestY0Y1Y2(const std::string& output, auto Y0, auto Y1, auto Y2)
|
||||
{
|
||||
std::stringstream expected;
|
||||
expected << "<Y0, Y1, Y2>: <" << Y0 << ", " << Y1 << ", " << Y2 << ">";
|
||||
EXPECT_TRUE(output.find(expected.str()) != std::string::npos);
|
||||
}
|
||||
void TestX0X1(const std::string& output, auto X0, auto X1)
|
||||
{
|
||||
std::stringstream expected;
|
||||
expected << "<X0, X1>: <" << X0 << ", " << X1 << ">";
|
||||
EXPECT_TRUE(output.find(expected.str()) != std::string::npos);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(PrintStaticEncodingPatternTest, PrintThreadRakedPattern)
|
||||
{
|
||||
// Test printing thread raked pattern
|
||||
using PatternType =
|
||||
TileDistributionEncodingPattern2D<64, 8, 16, 4, tile_distribution_pattern::thread_raked>;
|
||||
PatternType pattern;
|
||||
|
||||
std::string output = CapturePrintOutput(pattern);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("BlockSize:64") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("YPerTile:8") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("XPerTile:16") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("VecSize:4") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("thread_raked") != std::string::npos);
|
||||
TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2);
|
||||
TestX0X1(output, PatternType::X0, PatternType::X1);
|
||||
}
|
||||
|
||||
TEST_F(PrintStaticEncodingPatternTest, PrintWarpRakedPattern)
|
||||
{
|
||||
// Test printing warp raked pattern
|
||||
using PatternType =
|
||||
TileDistributionEncodingPattern2D<128, 16, 32, 8, tile_distribution_pattern::warp_raked>;
|
||||
PatternType pattern;
|
||||
|
||||
std::string output = CapturePrintOutput(pattern);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("BlockSize:128") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("YPerTile:16") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("XPerTile:32") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("VecSize:8") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("warp_raked") != std::string::npos);
|
||||
TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2);
|
||||
TestX0X1(output, PatternType::X0, PatternType::X1);
|
||||
}
|
||||
|
||||
TEST_F(PrintStaticEncodingPatternTest, PrintBlockRakedPattern)
|
||||
{
|
||||
// Test printing block raked pattern
|
||||
using PatternType =
|
||||
TileDistributionEncodingPattern2D<256, 32, 64, 16, tile_distribution_pattern::block_raked>;
|
||||
PatternType pattern;
|
||||
|
||||
std::string output = CapturePrintOutput(pattern);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("BlockSize:256") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("YPerTile:32") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("XPerTile:64") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("VecSize:16") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("block_raked") != std::string::npos);
|
||||
TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2);
|
||||
TestX0X1(output, PatternType::X0, PatternType::X1);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
66
test/ck_tile/utility/print/test_print_tuple.cpp
Normal file
66
test/ck_tile/utility/print/test_print_tuple.cpp
Normal file
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintTupleTest : public PrintTest
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(PrintTupleTest, PrintSimpleTuple)
|
||||
{
|
||||
// Test printing tuple with numbers
|
||||
auto tup = make_tuple(number<1>{}, number<5>{}, number<8>{});
|
||||
|
||||
std::string output = CapturePrintOutput(tup);
|
||||
|
||||
// Verify the output format matches tuple print implementation
|
||||
EXPECT_TRUE(output.find("tuple<") == 0);
|
||||
EXPECT_TRUE(output.find("1") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("5") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("8") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '>');
|
||||
}
|
||||
|
||||
TEST_F(PrintTupleTest, PrintSingleElementTuple)
|
||||
{
|
||||
// Test printing tuple with single element
|
||||
auto tup = make_tuple(number<42>{});
|
||||
|
||||
std::string output = CapturePrintOutput(tup);
|
||||
|
||||
EXPECT_TRUE(output.find("tuple<") == 0);
|
||||
EXPECT_TRUE(output.find("42") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '>');
|
||||
}
|
||||
|
||||
TEST_F(PrintTupleTest, PrintEmptyTuple)
|
||||
{
|
||||
// Test printing empty tuple
|
||||
auto tup = make_tuple();
|
||||
|
||||
std::string output = CapturePrintOutput(tup);
|
||||
|
||||
EXPECT_EQ(output, "tuple<>");
|
||||
}
|
||||
|
||||
TEST_F(PrintTupleTest, PrintMixedTypeTuple)
|
||||
{
|
||||
// Test printing tuple with mixed types (numbers and constants)
|
||||
auto tup = make_tuple(number<10>{}, constant<20>{}, number<30>{});
|
||||
|
||||
std::string output = CapturePrintOutput(tup);
|
||||
|
||||
EXPECT_TRUE(output.find("tuple<") == 0);
|
||||
EXPECT_TRUE(output.find("10") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("20") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("30") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '>');
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user