Merge branch 'develop' into lwpck-4181

This commit is contained in:
khushbu
2025-12-15 15:28:16 -05:00
166 changed files with 8195 additions and 2289 deletions

5
.gitignore vendored
View File

@@ -83,6 +83,11 @@ __pycache__/
.cache/
# Generated test data
test_data/*
!test_data/*.py
!test_data/*.sh
# Exceptions to build* patterns above
# The experimental/builder directory should be tracked despite matching build*
!experimental/builder

View File

@@ -7,6 +7,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
### Added
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
* Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines.
### Changed

10
Jenkinsfile vendored
View File

@@ -1476,15 +1476,19 @@ pipeline {
setup_args = "NO_CK_BUILD"
execute_args = """ cd ../build && \
../script/cmake-ck-dev.sh ../ gfx90a && \
make -j64 test_grouped_convnd_fwd_dataset_xdl && \
make -j64 test_grouped_convnd_fwd_dataset_xdl \
test_grouped_convnd_bwd_data_dataset_xdl \
test_grouped_convnd_bwd_weight_dataset_xdl && \
cd ../test_data && \
# Dataset generation modes:
# - small: ~60 test cases (minimal, quick testing - 3 models, 2 batch sizes, 2 image sizes)
# - half: ~300 test cases (moderate coverage - 16 models, 3 batch sizes, 5 image sizes), ~ 17 hours testing time
# - full: ~600 test cases (comprehensive - 16 models, 5 batch sizes, 9 image sizes), ~ 40 hours testing time
./generate_test_dataset.sh half && \
./generate_test_dataset.sh small && \
cd ../build && \
./bin/test_grouped_convnd_fwd_dataset_xdl"""
./bin/test_grouped_convnd_fwd_dataset_xdl && \
./bin/test_grouped_convnd_bwd_data_dataset_xdl && \
./bin/test_grouped_convnd_bwd_weight_dataset_xdl"""
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)

View File

@@ -1,2 +1,2 @@
rocm-docs-core[api_reference]==1.31.0
rocm-docs-core[api_reference]==1.31.1
sphinxcontrib-bibtex==2.6.5

View File

@@ -237,7 +237,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core[api-reference]==1.31.0
rocm-docs-core[api-reference]==1.31.1
# via -r requirements.in
rpds-py==0.24.0
# via

View File

@@ -233,7 +233,20 @@ int run_contraction_bilinear_example(int argc, char* argv[])
}
}
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
if(ck::is_gfx11_supported())
{
return ck::utils::check_err(e_ms_ns_device_result,
e_ms_ns_host_result,
"Error: Incorrect results!",
1e-4,
1e-4)
? 0
: 1;
}
else
{
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
}
return 0;

View File

@@ -216,7 +216,20 @@ int run_contraction_scale_example(int argc, char* argv[])
}
}
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
if(ck::is_gfx11_supported())
{
return ck::utils::check_err(e_ms_ns_device_result,
e_ms_ns_host_result,
"Error: Incorrect results!",
1e-4,
1e-4)
? 0
: 1;
}
else
{
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
}
return 0;

View File

@@ -65,7 +65,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS
# there is no corresponding instance for parameters).
if(BUILD_TESTING)
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*)
endif()
# generate a list of kernels, but not actually emit files at config sta

View File

@@ -76,7 +76,8 @@ using fmha_traits = ck_tile::TileFmhaTraits<{F_spad},
{F_dropout},
{F_qscale},
{F_occupancy},
{F_skip}>;
{F_skip},
{F_sink}>;
using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
@@ -113,7 +114,7 @@ using fmha_kernel = {F_kernel}<fmha_pipeline, fmha_epilogue>;
using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
{F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>;
template<>
float fmha_fwd_<trait, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_args a)
@@ -229,9 +230,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hd
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>;
return fmha_fwd_<trait_, {F_arch.tag}>(s, a);
}}
"""
@@ -278,13 +279,14 @@ class FmhaFwdApiTrait:
dvpad: str
skip: str
tr_load: str
sink: str
constraint: CppConstraint
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
)
@property
@@ -384,6 +386,7 @@ class FmhaFwdPipeline:
F_mask: str # value from MASK_MAP
F_skip: str # true/false
F_trload: str # true/false
F_sink: str # true/false
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@property
@@ -454,6 +457,10 @@ class FmhaFwdPipeline:
n += "_trload"
else:
n += "_ntrload"
if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"
return n
@@ -543,6 +550,7 @@ class FmhaFwdApiPool:
F_trload=BOOL_MAP[trait.tr_load],
F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
F_qscale=QSCALE_MAP[trait.qscale],
F_sink=BOOL_MAP[trait.sink],
F_scheck=trait.scheck,
F_seqtune=trait.seqtune(max_bm0),
F_skcheck=trait.skcheck,
@@ -683,6 +691,7 @@ class FmhaFwdKernel:
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag),
F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag),
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
)
@property
@@ -725,6 +734,7 @@ class FmhaFwdKernel:
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip,
tr_load=self.F_pipeline.F_trload,
sink=self.F_pipeline.F_sink,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
)
@@ -957,52 +967,55 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
pipelines = []
if dtype in cls._DT_FP32:
qscale = "no"
for logits, mask, bias, lse, dropout, skip in itertools.product(
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
elif dtype in cls._DT_FP16_BF16:
qscale = "no"
for logits, mask, bias, lse, dropout, skip in itertools.product(
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
["t", "f"],
):
if hdim == 256 and hdim_v == 256:
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
# the below two is used for hdim vectorize load
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
else:
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
for logits, qscale, mask, bias, sink in itertools.product(
["f"],
["no", "pertensor"],
get_mask_map(mask_impl).keys(),
["no"],
["f", "t"],
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
elif dtype in ["fp8", "fp8fp16", "bf8"]:
# TODO
pass
@@ -1033,13 +1046,14 @@ class KernelComponentFactoryGfx950(
)
if dtype in cls._DT_FP16_BF16:
qscale = "no"
for logits, mask, bias, lse, dropout, skip in itertools.product(
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
["t", "f"],
):
if (
(hdim, hdim_v) in [(64, 64), (128, 128)]
@@ -1048,15 +1062,15 @@ class KernelComponentFactoryGfx950(
and dropout == "f"
and skip == "f"
):
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
# qr_async_trload_v3 only supports hdim=hdim_v=128 for now
if (hdim, hdim_v) == (128, 128):
# qr_async_trload_v3 only supports (generic) causal mask
for mask in ["no", "causal"]:
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip
F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
return pipelines
@@ -1105,23 +1119,24 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
pipelines = []
if dtype in cls._DT_FP16_BF16:
qscale = "no"
for logits, mask, bias, lse, dropout, skip in itertools.product(
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
return pipelines

View File

@@ -73,7 +73,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_pagedkv},
kHasUnevenSplits,
kMergeNumHeadGroupsSeqLenQ,
{F_occupancy}>;
{F_occupancy},
{F_sink}>;
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
@@ -117,7 +118,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
}} // anonymous namespace
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>;
#pragma clang diagnostic push
@@ -279,8 +280,8 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
@@ -332,6 +333,7 @@ class FmhaFwdSplitKVApiTrait:
dpad: str
dvpad: str
pagedkv: str
sink: str # sink or not
bn1comb: int # tile size along v head_dim of combine kernel
@property
@@ -339,7 +341,7 @@ class FmhaFwdSplitKVApiTrait:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-"
+ f"{self.dvpad}-{self.pagedkv}"
+ f"{self.dvpad}-{self.pagedkv}-{self.sink}"
)
@property
@@ -425,6 +427,7 @@ class FmhaFwdSplitKVPipeline:
F_lse: str #
F_squant: str #
F_pagedkv: str # t/f
F_sink: str # t/f
F_mask: str # value from MASK_MAP
@property
@@ -485,6 +488,10 @@ class FmhaFwdSplitKVPipeline:
n += "_pagedkv"
else:
n += "_npagedkv"
if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"
return n
@@ -567,6 +574,7 @@ class FmhaFwdSplitKVApiPool:
F_lse=BOOL_MAP[trait.lse],
F_squant=BOOL_MAP[trait.squant],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_sink=BOOL_MAP[trait.sink],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
@@ -667,6 +675,7 @@ class FmhaFwdSplitKVKernel:
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy=self.F_tile.F_occupancy,
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
@@ -740,19 +749,23 @@ class KernelComponentFactoryBase:
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, pagedkv in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
for logits, mask, bias, pagedkv, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
):
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None
@@ -908,6 +921,7 @@ def get_fwd_splitkv_blobs(
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
@@ -917,6 +931,7 @@ def get_fwd_splitkv_blobs(
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= mode == "batch"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
@@ -1075,6 +1090,7 @@ def write_blobs(
lse=kernel.F_pipeline.F_lse,
squant=kernel.F_pipeline.F_squant,
pagedkv=kernel.F_pipeline.F_pagedkv,
sink=kernel.F_pipeline.F_sink,
spad=kernel.F_pipeline.F_spad,
skpad=kernel.F_pipeline.F_skpad,
dpad=kernel.F_pipeline.F_dpad,

View File

@@ -65,7 +65,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad},
{F_pagedkv}, //pagedkv
{F_squant},
{F_occupancy},
{F_skip}>;
{F_skip},
{F_sink}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
@@ -100,7 +101,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>;
template<>
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
@@ -129,9 +130,9 @@ float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, con
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>;
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
}}
"""
@@ -163,12 +164,13 @@ class FmhaFwdApiTrait:
dpad: str
dvpad: str
skip: str
sink: str
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
)
@property
@@ -256,6 +258,7 @@ class FmhaFwdPipeline:
F_squant: str #
F_mask: str # value from MASK_MAP
F_skip: str # true/false
F_sink: str # true/false
@property
def name(self) -> str:
@@ -320,6 +323,10 @@ class FmhaFwdPipeline:
n += "_pagedkv"
else:
n += "_npagedkv"
if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"
return n
@@ -363,6 +370,7 @@ class FmhaFwdApiPool:
F_lse=BOOL_MAP[trait.lse],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_skip=BOOL_MAP[trait.skip],
F_sink=BOOL_MAP[trait.sink],
F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
@@ -480,6 +488,7 @@ class FmhaFwdKernel:
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
F_occupancy=self.F_tile.F_occupancy,
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
@@ -526,6 +535,7 @@ class FmhaFwdKernel:
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip,
sink=self.F_pipeline.F_sink,
)
@@ -539,22 +549,23 @@ class KernelComponentFactoryBase:
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, pagedkv, skip in itertools.product(
for logits, mask, bias, pagedkv, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t"],
["f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
pass # TODO
else:
@@ -678,6 +689,7 @@ def get_fwd_blobs(
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
@@ -687,6 +699,7 @@ def get_fwd_blobs(
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_fwd) integration

View File

@@ -266,6 +266,7 @@ struct fmha_fwd_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t sink_size;
ck_tile::index_t mask_type;
ck_tile::index_t min_seqlen_q;
@@ -352,6 +353,7 @@ struct fmha_fwd_pagedkv_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t sink_size;
ck_tile::index_t mask_type;
ck_tile::index_t min_seqlen_q;
};
@@ -442,6 +444,7 @@ struct fmha_fwd_splitkv_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t sink_size;
ck_tile::index_t mask_type;
};
@@ -561,6 +564,7 @@ struct fmha_batch_prefill_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t sink_size;
ck_tile::index_t mask_type;
float p_drop;
@@ -613,6 +617,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_o,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.min_seqlen_q,
args.p_drop,
@@ -663,6 +668,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.p_drop,
args.s_randval,
@@ -824,6 +830,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
args.batch_stride_v,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.min_seqlen_q);
}
@@ -869,6 +876,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
}
}();
@@ -935,6 +943,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
}
else
@@ -982,6 +991,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
}
}();
@@ -1142,6 +1152,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.batch_stride_v,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.p_drop,
args.s_randval,
@@ -1194,6 +1205,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.p_drop,
args.s_randval,
@@ -1228,7 +1240,8 @@ template <ck_tile::index_t HDim_,
bool kPadD_,
bool kPadDv_,
bool kUseTrLoad_,
bool kSkipMinSeqlenQ_ = false>
bool kSkipMinSeqlenQ_ = false,
bool kHasSink_ = false>
struct fmha_fwd_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -1254,6 +1267,7 @@ struct fmha_fwd_traits_
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
static constexpr bool kHasSink = kHasSink_;
};
template <typename Traits_, typename Arch = void>
@@ -1280,7 +1294,8 @@ template <ck_tile::index_t HDim_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kSkipMinSeqlenQ_ = false>
bool kSkipMinSeqlenQ_ = false,
bool kHasSink_ = false>
struct fmha_fwd_pagedkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -1305,6 +1320,7 @@ struct fmha_fwd_pagedkv_traits_
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
static constexpr bool kHasSink = kHasSink_;
};
template <typename Traits_, typename Arch = void>
@@ -1327,6 +1343,7 @@ template <ck_tile::index_t HDim_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kIsPagedKV_,
bool kHasSink_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
@@ -1354,6 +1371,7 @@ struct fmha_fwd_splitkv_traits_
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
static constexpr bool kHasSink = kHasSink_;
};
template <typename Traits_, typename Arch = void>
@@ -1440,6 +1458,7 @@ struct fmha_fwd_traits
bool has_dropout;
quant_scale_enum qscale_type;
bool skip_min_seqlen_q = false;
bool has_sink = false;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
@@ -1458,6 +1477,7 @@ struct fmha_fwd_pagedkv_traits
bool use_pagedkv = true;
bool do_fp8_static_quant = false;
bool skip_min_seqlen_q = false;
bool has_sink = false;
// TODO: padding check is inside this api
};
@@ -1477,6 +1497,7 @@ struct fmha_fwd_splitkv_traits
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool do_fp8_static_quant;
bool has_sink = false;
// TODO: padding check is inside this api
};
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,

View File

@@ -879,6 +879,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
traits.has_logits_soft_cap = 0.f < logits_soft_cap;
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.has_sink = mask.sink > 0 ? true : false;
traits.has_lse = lse;
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
@@ -1042,6 +1043,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.sink_size = mask.sink;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
@@ -1645,7 +1647,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
mask.left, mask.right, mask.sink, real_seqlen_q, real_seqlen_k));
}
else
{
@@ -1657,6 +1659,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
mask.sink,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
@@ -1666,6 +1669,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
mask.sink,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));

View File

@@ -25,6 +25,7 @@ struct mask_info
ck_tile::index_t seqlen_k;
ck_tile::index_t y, x;
ck_tile::index_t left, right; // FA style SWA left/right
ck_tile::index_t sink;
void serialize(std::ostream& os) const
{
@@ -58,13 +59,14 @@ struct mask_info
ck_tile::index_t window_size = std::stoi(v);
ck_tile::index_t left_size = -1;
ck_tile::index_t right_size = 0;
ck_tile::index_t sink_size = 0;
if(window_size > 0)
{
left_size = window_size / 2;
right_size = window_size - 1 - left_size;
}
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, t == "xt");
left_size, right_size, sink_size, y_total, x_total, t == "xt");
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
tmp.y = r.at(ck_tile::number<0>{});
@@ -79,27 +81,54 @@ struct mask_info
{
throw std::invalid_argument("invalid mask value: " + str);
}
ck_tile::index_t v0 = std::stoi(v.substr(0, found_1));
ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1));
tmp.type = mask_enum::window_generic;
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
auto found_2 = v.find(',', found_1 + 1);
ck_tile::index_t v1 = 0;
ck_tile::index_t sink = 0;
// ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
// TODO: some validation
if(t == "t")
{
if(found_2 != std::string::npos)
{
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
sink = atoi(v.substr(found_2 + 1).c_str());
}
else
{
v1 = atoi(v.substr(found_1 + 1).c_str());
sink = 0;
}
tmp.type = mask_enum::mask_top_left;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, true);
v0, v1, sink, y_total, x_total, true);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
tmp.sink = sink;
}
else if(t == "b")
{
if(found_2 != std::string::npos)
{
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
sink = atoi(v.substr(found_2 + 1).c_str());
}
else
{
v1 = atoi(v.substr(found_1 + 1).c_str());
sink = 0;
}
tmp.type = mask_enum::mask_bottom_right;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, false);
v0, v1, sink, y_total, x_total, false);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
tmp.sink = sink;
}
else if(t == "g")
{
@@ -108,6 +137,7 @@ struct mask_info
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
tmp.sink = 0;
}
}
else
@@ -126,6 +156,7 @@ struct mask_info
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
tmp.sink = 0;
}
else if(str == "2" || str == "b")
{
@@ -134,6 +165,7 @@ struct mask_info
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
tmp.sink = 0;
}
else
{

View File

@@ -0,0 +1,77 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
KNAME=1
export CK_WARMUP=0
export CK_REPEAT=1
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4
TEST_SPLITKV=0
TEST_APPENDKV=0
# options:
# -s: run splitkv tests
# -a: run appendkv tests
while getopts ":sa" opt; do
case "${opt}" in
s)
TEST_SPLITKV=1
;;
a)
TEST_APPENDKV=1
;;
*)
;;
esac
done
run_fp16_bf16_tests() {
local NUM_SPLITS="1"
local PAGE_BLOCK_SIZE="0"
local CACHE_BATCH_IDX="0"
if [ $TEST_SPLITKV -eq 1 ] ; then
NUM_SPLITS="$NUM_SPLITS 2 3"
PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
fi
for prec in "fp16"; do
for mode in 1 0 ; do
for perm in 0 1 ; do
for vlayout in "r" "c" ; do
for batch in 1 4; do
for head in 1; do
for h_k in 1; do
for q_seq in 128 512 ; do
for kv_seq in 128 1024; do
for hdim in 32 64 128 256; do #256
for lse in 0 1 ; do
for bias in "e" ; do
for p_drop in 0.0 0.2; do # 0.0
for mask in "t:2,0,4" "b:1,0,2"; do
for num_splits in $NUM_SPLITS ; do
for page_block_size in $PAGE_BLOCK_SIZE ; do
for cache_batch_idx in $CACHE_BATCH_IDX ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=$batch -h=$head -h_k=$h_k -d=16 -d_v=$hdim -s=$q_seq -s_k=$kv_seq -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS -mask=$mask
done ; done ; done ; done ; done
done ; done ; done ; done ; done
done ; done ; done ; done ; done
done ; done
}
set -x
run_fp16_bf16_tests
set +x

View File

@@ -39,6 +39,7 @@ function print_log_header(){
#run verification tests
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
#run performance benchmarks
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"

View File

@@ -0,0 +1,86 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# TODO: run this script from CK root or build directory
#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd"
set -euo pipefail
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
EXE_NAME=tile_example_fmha_fwd
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
KNAME=1
GPU_arch=$GPU_arch
if [ -z "$GPU_arch" ] ; then
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
fi
set -x
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2
# window_size[2,0], sink_size = 2
# x=1/y=3
# 1 * * * * * * * 1 * * * * * * *
# 1 1 * * * * * * 1 1 * * * * * *
# 1 1 1 * * * * * ----> 1 1 1 * * * * *
# * 1 1 1 * * * * 1 1 1 1 * * * *
# * * 1 1 1 * * * 1 1 1 1 1 * * *
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
# * * * * * 1 1 1 1 1 * * * 1 1 1
# l=2/r=0(tl) l=2/r=0/s=2(tl)
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2
# x=4/y=1
# 1 1 1 1 * * * * 1 1 1 1 * * * *
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * *
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
# l=0/r=3(tl) l=0/r=3/s=2(tl)
# l=3/r=0(br) l=3/r=0/s=2(br)
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2
# x=4/y=-1
# * * 1 1 * * * * 1 1 1 1 * * * *
# * * * 1 1 * * * 1 1 * 1 1 * * *
# * * * * 1 1 * * ----> 1 1 * * 1 1 * *
# * * * * * 1 1 * 1 1 * * * 1 1 *
# * * * * * * 1 1 1 1 * * * * 1 1
# l=1/r=0(br) l=1/r=0/s=2(br)
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2
# x=-1/y=5
# * * * * * * * * * * * *
# * * * * * * * * * * * *
# 1 * * * * * 1 * * * * *
# 1 1 * * * * 1 1 * * * *
# 1 1 1 * * * ----> 1 1 1 * * *
# * 1 1 1 * * 1 1 1 1 * *
# * * 1 1 1 * 1 1 1 1 1 *
# * * * 1 1 1 1 1 * 1 1 1
# l=2/r=0(br) l=2/r=0/s=2(br)
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2
# x=-1/y=8
# * * * * * * * * * *
# * * * * * * * * * *
# 1 * * * * ----> 1 * * * *
# 1 1 * * * 1 1 * * *
# 1 1 1 * * 1 1 1 * *
# 1 1 1 1 * 1 1 1 1 *
# 1 1 1 1 1 1 1 1 1 1
# 1 1 1 1 1 1 1 1 1 1
# l=2/r=0(br) l=2/r=0/s=2(br)

View File

@@ -12,40 +12,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
struct GemmConfigBase
{
static constexpr bool kPadM = false;
@@ -122,7 +88,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -141,7 +108,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -160,7 +128,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -204,7 +173,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -223,7 +193,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -242,7 +213,8 @@ struct GemmConfigComputeV5 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
@@ -282,7 +254,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
@@ -306,7 +279,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
@@ -459,7 +433,7 @@ struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
};
auto create_args()
inline auto create_args()
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")

View File

@@ -197,8 +197,8 @@ bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
return pass;
}
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>
parse_gemm_size(ck_tile::ArgParser& arg_parser)
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> inline parse_gemm_size(
ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");

View File

@@ -63,25 +63,30 @@ struct UniversalInvoker
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups,
false, /*FixedVectorSize_*/
1, /*VectorSizeC_*/
false, /*TiledMMAPermuteN_*/
1, /*BlockedXDLN_PerWarp_*/
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);

View File

@@ -11,40 +11,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename DataType>
struct GemmTypeConfig;
@@ -111,7 +77,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -134,7 +101,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -157,7 +125,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -178,7 +147,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool kPadK = true;
@@ -203,7 +173,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;

View File

@@ -11,24 +11,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
struct GemmConfigBase
{
static constexpr bool kPadM = false;

View File

@@ -10,40 +10,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename DataType>
struct GemmTypeConfig;
@@ -100,7 +66,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType, bool Persistent>
@@ -117,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;

View File

@@ -24,39 +24,6 @@ inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
return combined_hash;
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
@@ -91,7 +58,7 @@ struct GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool kPadK = true;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
@@ -124,7 +91,8 @@ struct GemmConfigQuantDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
@@ -140,7 +108,8 @@ struct GemmConfigRowColQuant : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
@@ -157,7 +126,7 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool PreshuffleQuant = true;
};
@@ -176,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
@@ -206,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
@@ -236,7 +205,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>

View File

@@ -391,20 +391,12 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped)
{
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error(
"K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode");
}
}
ck_tile::index_t AQK, BQK, BQN = 0;
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
BQK = 0; // No B quantization
AQK = ck_tile::integer_divide_ceil(
K, QuantGroupSize::kK); // Group quantization: AQK = K / GroupSize
BQK = 0; // No B quantization
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{

View File

@@ -7,46 +7,46 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
struct GemmConfigBase
struct GemmConfigurationBase
{
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool PAD_M = true;
static constexpr bool PAD_N = true;
static constexpr bool PAD_K = true;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool PERMUTE_A = false;
static constexpr bool PERMUTE_B = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool TRANSPOSE_C = false;
static constexpr bool USE_STRUCTURED_SPARSITY = false;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int BLOCK_PER_CU = 1;
static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NUM_WAVE_GROUPS = 1;
static constexpr bool PRESHUFFLE = false;
static constexpr bool DOUBLE_SMEM_BUFFER = false;
};
template <typename PrecType, bool Persistent_>
struct GemmConfigMemoryInterwave : public GemmConfigBase
template <typename PrecisionType, bool IsPersistent>
struct GemmConfigurationMemoryInterwave : public GemmConfigurationBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 16;
static constexpr ck_tile::index_t M_TILE = 256;
static constexpr ck_tile::index_t N_TILE = 256;
static constexpr ck_tile::index_t K_TILE = 16;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_WARP = 2;
static constexpr ck_tile::index_t N_WARP = 2;
static constexpr ck_tile::index_t K_WARP = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr ck_tile::index_t M_WARP_TILE = 32;
static constexpr ck_tile::index_t N_WARP_TILE = 32;
static constexpr ck_tile::index_t K_WARP_TILE = sizeof(PrecisionType) == 2 ? 8 : 16;
static constexpr bool Persistent = Persistent_;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool PERSISTENT = IsPersistent;
static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
struct StreamKGemmTypeConfig
struct StreamKGemmTypeConfiguration
{
using ADataType = ADataType_;
using BDataType = BDataType_;
@@ -54,7 +54,7 @@ struct StreamKGemmTypeConfig
using CDataType = CDataType_;
};
auto create_args(int argc, char* argv[])
auto createArgs(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "512", "m dimension")

View File

@@ -12,31 +12,35 @@ static constexpr inline auto is_row_major(Layout)
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
auto calculateRtolAtol(const ck_tile::index_t k_dim,
const ck_tile::index_t k_batch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
const auto relative_tolerance =
ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(k_dim, k_batch));
const auto absolute_tolerance =
ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / k_batch, ck_tile::integer_divide_ceil(k_dim, k_batch));
// Calculate error due to multiple WGs working in the same C macro tile
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
const auto relative_tolerance_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(k_batch);
const auto absolute_tolerance_split_k =
ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(max_accumulated_value,
k_batch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
return ck_tile::make_tuple(std::max(relative_tolerance, relative_tolerance_split_k),
std::max(absolute_tolerance, absolute_tolerance_split_k));
}
template <typename GemmConfig,
template <typename GemmConfiguration,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename AccumulatorDataType,
typename CDataType,
typename ALayout,
typename BLayout,
@@ -45,102 +49,107 @@ template <typename GemmConfig,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& s);
const ck_tile::stream_config& stream_config);
template <typename GemmConfig,
template <typename GemmConfiguration,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename AccumulatorDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
std::tuple<float, ck_tile::index_t>
invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
int n_warmup,
int n_repeat,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy)
std::tuple<float, ck_tile::index_t> invokeGemm(ck_tile::DeviceMem& a_m_k_device_memory,
ck_tile::DeviceMem& b_k_n_device_memory,
ck_tile::DeviceMem& c_m_n_device_memory,
ck_tile::index_t m_dim,
ck_tile::index_t n_dim,
ck_tile::index_t k_dim,
ck_tile::index_t stride_a,
ck_tile::index_t stride_b,
ck_tile::index_t stride_c,
int warmup_iterations,
int repeat_iterations,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy)
{
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
stride_C};
ck_tile::StreamKHostArgs args{a_m_k_device_memory.GetDeviceBuffer(),
b_k_n_device_memory.GetDeviceBuffer(),
c_m_n_device_memory.GetDeviceBuffer(),
m_dim,
n_dim,
k_dim,
stride_a,
stride_b,
stride_c};
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
std::tuple<float, ck_tile::index_t> average_time_and_batch;
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
{
ave_time_and_batch = gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
CDEElementWise,
ck_tile::StreamKReductionStrategy::Atomic>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
average_time_and_batch = gemm<GemmConfiguration,
ADataType,
BDataType,
DsDataType,
AccumulatorDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
CDEElementWise,
ck_tile::StreamKReductionStrategy::Atomic>(
args,
ck_tile::stream_config{
nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache});
}
else /*Reduction*/
{
ave_time_and_batch = gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
CDEElementWise,
ck_tile::StreamKReductionStrategy::Reduction>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
average_time_and_batch = gemm<GemmConfiguration,
ADataType,
BDataType,
DsDataType,
AccumulatorDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
CDEElementWise,
ck_tile::StreamKReductionStrategy::Reduction>(
args,
ck_tile::stream_config{
nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache});
}
return ave_time_and_batch;
return average_time_and_batch;
}
template <typename CDataType>
bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
const ck_tile::HostTensor<CDataType>& c_m_n_ref,
const ck_tile::tuple<double, double>& rtol_atol,
const char* variant)
bool doVerify(const ck_tile::HostTensor<CDataType>& c_m_n_device_result,
const ck_tile::HostTensor<CDataType>& c_m_n_reference,
const ck_tile::tuple<double, double>& relative_absolute_tolerances,
const char* variant)
{
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_ref,
bool pass = ck_tile::check_err(c_m_n_device_result,
c_m_n_reference,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
relative_absolute_tolerances.at(ck_tile::number<0>{}),
relative_absolute_tolerances.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
std::cout << "Relative error threshold: "
<< relative_absolute_tolerances.at(ck_tile::number<0>{})
<< " Absolute error threshold: "
<< relative_absolute_tolerances.at(ck_tile::number<1>{}) << std::endl;
std::cout << "The " << variant << " verification result is:" << (pass ? "correct" : "fail")
<< std::endl;
return pass;
}
ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string& strategy)
ck_tile::StreamKReductionStrategy getReductionStrategyValue(const std::string& strategy)
{
if(strategy == "atomic")
{
@@ -156,172 +165,169 @@ ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string
}
}
template <typename GemmConfig,
typename TypeConfig,
template <typename GemmConfiguration,
typename TypeConfiguration,
typename ALayout,
typename BLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
int runGemmExampleWithLayouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
auto [result, arg_parser] = createArgs(argc, argv);
if(!result)
return -1;
static_assert(!GemmConfig::Preshuffle, "Not implemented");
static_assert(!GemmConfig::UseStructuredSparsity, "Not implemented");
static_assert(!GemmConfig::PermuteA, "Not implemented");
static_assert(!GemmConfig::PermuteB, "Not implemented");
static_assert(!GemmConfiguration::PRESHUFFLE, "Not implemented");
static_assert(!GemmConfiguration::USE_STRUCTURED_SPARSITY, "Not implemented");
static_assert(!GemmConfiguration::PERMUTE_A, "Not implemented");
static_assert(!GemmConfiguration::PERMUTE_B, "Not implemented");
using ADataType = typename TypeConfig::ADataType;
using BDataType = typename TypeConfig::BDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
using ADataType = typename TypeConfiguration::ADataType;
using BDataType = typename TypeConfiguration::BDataType;
using AccumulatorDataType = typename TypeConfiguration::AccDataType;
using CDataType = typename TypeConfiguration::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t m_dim = arg_parser.get_int("m");
ck_tile::index_t n_dim = arg_parser.get_int("n");
ck_tile::index_t k_dim = arg_parser.get_int("k");
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
int warmup_iterations = arg_parser.get_int("warmup");
int repeat_iterations = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
bool flush_cache = arg_parser.get_bool("flush_cache");
ck_tile::StreamKReductionStrategy reduction_strategy =
get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));
getReductionStrategyValue(arg_parser.get_str("reduction_strategy"));
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
stride_a = ck_tile::get_default_stride(m_dim, k_dim, stride_a, is_row_major(a_layout));
stride_b = ck_tile::get_default_stride(k_dim, n_dim, stride_b, is_row_major(b_layout));
stride_c = ck_tile::get_default_stride(m_dim, n_dim, stride_c, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<ADataType> a_m_k_host(
ck_tile::host_tensor_descriptor(m_dim, k_dim, stride_a, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n_host(
ck_tile::host_tensor_descriptor(k_dim, n_dim, stride_b, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_device_result(
ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_host);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_host);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k_host);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n_host);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_host);
}
else
{
a_m_k.SetZero();
b_k_n.SetZero();
a_m_k_host.SetZero();
b_k_n_host.SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
ck_tile::DeviceMem a_m_k_device_memory(a_m_k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_device_memory(b_k_n_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_device_memory(c_m_n_device_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
a_m_k_device_memory.ToDevice(a_m_k_host.data());
b_k_n_device_memory.ToDevice(b_k_n_host.data());
c_m_n_device_memory.SetZero();
c_m_n_device_result.SetZero();
auto [average_time, num_wgs_per_tile] = invokeGemm<GemmConfiguration,
ADataType,
BDataType,
ck_tile::tuple<>,
AccumulatorDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_device_memory,
b_k_n_device_memory,
c_m_n_device_memory,
m_dim,
n_dim,
k_dim,
stride_a,
stride_b,
stride_c,
warmup_iterations,
repeat_iterations,
flush_cache,
reduction_strategy);
auto [ave_time, num_wgs_per_tile] = invoke_gemm<GemmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
n_warmup,
n_repeat,
flush_cache,
reduction_strategy);
c_m_n_device_memory.FromDevice(c_m_n_device_result.data());
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
std::size_t flop = std::size_t(2) * m_dim * n_dim * k_dim;
std::size_t num_byte = sizeof(ADataType) * m_dim * k_dim + sizeof(BDataType) * n_dim * k_dim +
sizeof(CDataType) * m_dim * n_dim;
float tflops = static_cast<float>(flop) / 1.E9 / average_time;
float gb_per_sec = num_byte / 1.E6 / average_time;
std::cout << "Run Gemm kernel with M=" << m_dim << " N=" << n_dim << " K=" << k_dim
<< " StrideA=" << stride_a << " StrideB=" << stride_b << " StrideC=" << stride_c
<< " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name
<< " C_Layout=" << CLayout::name
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
<< " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << average_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
bool pass = false;
// Memory on host to store gpu reference result
ck_tile::HostTensor<CDataType> c_m_n_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_ref.SetZero();
ck_tile::HostTensor<CDataType> c_m_n_reference(
ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));
c_m_n_reference.SetZero();
if(arg_parser.get_int("v") == 1) // Validate on the CPU
{
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_ref);
ck_tile::reference_gemm<ADataType, BDataType, AccumulatorDataType, CDataType>(
a_m_k_host, b_k_n_host, c_m_n_reference);
const float max_accumulated_value =
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, num_wgs_per_tile, max_accumulated_value);
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
*std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
const auto relative_absolute_tolerances =
calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
k_dim, num_wgs_per_tile, max_accumulated_value);
pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "CPU");
}
else if(arg_parser.get_int("v") == 2) // Validate on the GPU
{
// Memory on device to store gpu reference result
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::DeviceMem c_m_n_gpu_buffer_reference(
c_m_n_reference.get_element_space_size_in_bytes());
c_m_n_gpu_buffer_reference.SetZero();
ADataType* d_A = static_cast<ADataType*>(a_m_k_device_memory.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_device_memory.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buffer_reference.GetDeviceBuffer());
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
AccumulatorDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
CLayout>(
d_A, d_B, d_C, m_dim, n_dim, k_dim, stride_a, stride_b, stride_c);
c_m_n_gpu_buffer_reference.FromDevice(c_m_n_reference.data());
const float max_accumulated_value =
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, num_wgs_per_tile, max_accumulated_value);
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU");
*std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
const auto relative_absolute_tolerances =
calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
k_dim, num_wgs_per_tile, max_accumulated_value);
pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "GPU");
}
return pass;

View File

@@ -4,11 +4,11 @@
#include "gemm_utils.hpp"
#include "ck_tile/ops/common.hpp"
template <typename GemmConfig,
template <typename GemmConfiguration,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename AccumulatorDataType,
typename CDataType,
typename ALayout,
typename BLayout,
@@ -17,43 +17,49 @@ template <typename GemmConfig,
typename CDEElementWise,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& s)
const ck_tile::stream_config& stream_config)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<GemmConfiguration::M_TILE,
GemmConfiguration::N_TILE,
GemmConfiguration::K_TILE>,
ck_tile::sequence<GemmConfiguration::M_WARP,
GemmConfiguration::N_WARP,
GemmConfiguration::K_WARP>,
ck_tile::sequence<GemmConfiguration::M_WARP_TILE,
GemmConfiguration::N_WARP_TILE,
GemmConfiguration::K_WARP_TILE>,
GemmConfiguration::PERMUTE_A,
GemmConfiguration::PERMUTE_B>;
using TilePartitioner =
ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
using TilePartitioner = ck_tile::
StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfiguration::PERSISTENT>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
GemmConfig::Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<GemmConfiguration::PAD_M,
GemmConfiguration::PAD_N,
GemmConfiguration::PAD_K,
GemmConfiguration::DOUBLE_SMEM_BUFFER,
ALayout,
BLayout,
ELayout,
GemmConfiguration::TRANSPOSE_C,
GemmConfiguration::USE_STRUCTURED_SPARSITY,
GemmConfiguration::PERSISTENT,
GemmConfiguration::NUM_WAVE_GROUPS,
GemmConfiguration::PRESHUFFLE>;
const auto Run = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
const auto runKernel = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
GemmConfig::Scheduler>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccumulatorDataType,
GemmShape,
GemmUniversalTraits,
GemmConfiguration::SCHEDULER>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
@@ -61,39 +67,39 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
AccumulatorDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfiguration::M_WARP,
GemmConfiguration::N_WARP,
GemmConfiguration::M_WARP_TILE,
GemmConfiguration::N_WARP_TILE,
GemmConfiguration::K_WARP_TILE,
UniversalGemmProblem::TransposeC,
memory_operation.value,
GemmConfig::NumWaveGroups>>;
GemmConfiguration::NUM_WAVE_GROUPS>>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
auto kernel_args = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
dim3 grids = Kernel::GridSize(kargs.tile_partitioner);
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
if(!Kernel::IsSupportedArgument(kernel_args))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
if(stream_config.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
@@ -109,7 +115,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
{
// Clear the output C tensor results after each repetition of the kernel
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
}
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
{
@@ -120,45 +126,47 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
std::function<void()> preprocess = reset_data_buffers;
float ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
float average_time =
ck_tile::launch_kernel_time_mask(stream_config,
preprocess,
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
Kernel{}, grids, blocks, 0, kernel_args));
ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
return std::tuple{ave_time, num_wgs_per_tile};
ck_tile::index_t num_wgs_per_tile =
kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
return std::tuple{average_time, num_wgs_per_tile};
};
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// Since we are doing stream K, in the case of
// atomics, multiple workgroups may write to the same
// output tile in the C tensor, so we must atomic add
// the results (not set)
ck_tile::memory_operation_enum::atomic_add>{});
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// Since we are doing stream K, in the case of
// atomics, multiple workgroups may write to the
// same output tile in the C tensor, so we must
// atomic add the results (not set)
ck_tile::memory_operation_enum::atomic_add>{});
}
else // We are using ck_tile::StreamKReductionStrategy::Reduction
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// In this case, there is only ever 1 WG writing final
// results to each macro tile in the C tensor, so we
// can do a set.
ck_tile::memory_operation_enum::set>{});
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// In this case, there is only ever 1 WG writing
// final results to each macro tile in the C
// tensor, so we can do a set.
ck_tile::memory_operation_enum::set>{});
}
}
#include "run_gemm_example.inc"
template <typename GemmConfig, typename TypeConfig>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
template <typename GemmConfiguration, typename TypeConfiguration>
int runGemmExamplePrecisionType(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig, TypeConfig>(
return runGemmExampleWithLayouts<GemmConfiguration, TypeConfiguration>(
argc, argv, Row{}, Col{}, Row{});
}
else
@@ -169,72 +177,74 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
return 0;
}
template <template <typename PreType, bool Persistent_> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
template <template <typename PrecisionType, bool IsPersistent> typename GemmConfiguration>
int runGemmExample(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
auto [result, arg_parser] = createArgs(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
auto persistent_dp = arg_parser.get_bool("persistent_dp");
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
auto persistent_data_parallel = arg_parser.get_bool("persistent_dp");
if(data_type == "bf16")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf16_t>;
if(persistent_dp)
using TypeConfiguration = StreamKGemmTypeConfiguration<ck_tile::bf16_t>;
if(persistent_data_parallel)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf16_t, true>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf16_t, false>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
}
else if(data_type == "fp16")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::half_t>;
if(persistent_dp)
using TypeConfiguration = StreamKGemmTypeConfiguration<ck_tile::half_t>;
if(persistent_data_parallel)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::half_t, true>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::half_t, false>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
}
else if(data_type == "fp8")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
if(persistent_dp)
using TypeConfiguration =
StreamKGemmTypeConfiguration<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
if(persistent_data_parallel)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::fp8_t, true>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::fp8_t, false>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
}
else if(data_type == "bf8")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
if(persistent_dp)
using TypeConfiguration =
StreamKGemmTypeConfiguration<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
if(persistent_data_parallel)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf8_t, true>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf8_t, false>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
}
else
@@ -247,5 +257,5 @@ int run_gemm_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
return !run_gemm_example<GemmConfigMemoryInterwave>(argc, argv);
return !runGemmExample<GemmConfigurationMemoryInterwave>(argc, argv);
}

View File

@@ -12,19 +12,21 @@ This project is a prototype for a more general builder pattern for all of compos
## Design descriptions
- [CK Builder design description](include/ck_tile/builder/README.md)
- [CK Builder design description](include/ck_tile/builder/README.md)
- [CK Builder factory design](include/ck_tile/builder/factory/README.md)
- [CK Builder testing design](include/ck_tile/builder/testing/README.md)
## Directory Structure
- `include/ck_tile/builder/`
- `include/ck_tile/builder/`
Core builder headers and public API.
- `include/ck_tile/builder/reflect`
Reflection mechanism.
- `include/ck_tile/builder/factory`
Compile-time dispatch from builder descriptors to our exisitng specialized convolution kernel implementations.
- `test/`
- `test/`
Unit tests and example usage of the builder pattern.
- `CMakeLists.txt`
- `CMakeLists.txt`
CMake configuration for building the experimental builder and its tests.
## CMake Configuration

View File

@@ -153,6 +153,7 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
case ConvFwdSpecialization::ODD_C: return ck_conv_spec::OddC;
default: throw "Unknown ConvFwdSpecialization";
}
}

View File

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/// @file
/// @brief Implementation of the describe() function template for convolution kernels
#pragma once
#include "ck_tile/builder/reflect/conv_description.hpp"
#include "ck_tile/builder/reflect/conv_traits.hpp"
namespace ck_tile::reflect {
/// @brief Factory function to create ConvDescription from a convolution instance type
/// @tparam Instance The convolution instance type (must have ConvTraits)
/// @return A ConvDescription object populated with the instance's configuration details
template <conv::HasConvTraits Instance>
conv::ConvDescription describe()
{
using Traits = conv::ConvTraits<Instance>;
return conv::ConvDescription(
conv::ConvSignatureInfo{
.spatial_dim = Traits::spatial_dim,
.direction = Traits::direction,
.input_layout = Traits::layout[0],
.weight_layout = Traits::layout[1],
.output_layout = Traits::layout[2],
.data_type = Traits::data_type,
.input_element_op = Traits::input_element_op,
.weight_element_op = Traits::weight_element_op,
.output_element_op = Traits::output_element_op,
},
conv::GemmAlgorithmInfo{
.thread_block_size = Traits::thread_block_size,
.tile_dims = Traits::tile_dims,
.warp_gemm = Traits::warp_gemm,
.a_tile_transfer = Traits::a_tile_transfer,
.b_tile_transfer = Traits::b_tile_transfer,
.c_tile_transfer = Traits::c_tile_transfer,
.pipeline_version = Traits::pipeline_version,
.pipeline_scheduler = Traits::pipeline_scheduler,
.conv_specialization = Traits::conv_specialization,
.padding = Traits::gemm_padding,
},
[]() { return reflect::instance_string<Instance>(); });
}
} // namespace ck_tile::reflect

View File

@@ -25,7 +25,7 @@
#include <functional>
#include <ck_tile/builder/conv_signature_concepts.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/conv_types.hpp>
#include <ck_tile/builder/reflect/description.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/tree_formatter.hpp>
@@ -249,41 +249,7 @@ class ConvDescription : public Description
GemmAlgorithmInfo algorithm_;
std::function<std::string()> instance_string_getter_;
};
} // namespace conv
/// @brief Factory function to create ConvDescription from a convolution instance type
/// @tparam Instance The convolution instance type (must have ConvTraits specialization)
/// @return A ConvDescription object populated with the instance's configuration details
template <conv::HasConvTraits Instance>
conv::ConvDescription describe()
{
using Traits = conv::ConvTraits<Instance>;
return conv::ConvDescription(
conv::ConvSignatureInfo{
.spatial_dim = Traits::spatial_dim,
.direction = Traits::direction,
.input_layout = Traits::layout[0],
.weight_layout = Traits::layout[1],
.output_layout = Traits::layout[2],
.data_type = Traits::data_type,
.input_element_op = Traits::input_element_op,
.weight_element_op = Traits::weight_element_op,
.output_element_op = Traits::output_element_op,
},
conv::GemmAlgorithmInfo{
.thread_block_size = Traits::thread_block_size,
.tile_dims = Traits::tile_dims,
.warp_gemm = Traits::warp_gemm,
.a_tile_transfer = Traits::a_tile_transfer,
.b_tile_transfer = Traits::b_tile_transfer,
.c_tile_transfer = Traits::c_tile_transfer,
.pipeline_version = Traits::pipeline_version,
.pipeline_scheduler = Traits::pipeline_scheduler,
.conv_specialization = Traits::conv_specialization,
.padding = Traits::gemm_padding,
},
[]() { return reflect::instance_string<Instance>(); });
}
} // namespace ck_tile::reflect

View File

@@ -10,8 +10,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/pipeline_enum.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include "ck_tile/builder/conv_builder.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/reflect/conv_types.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
#include "ck_tile/builder/types.hpp"
@@ -161,103 +161,19 @@ constexpr auto convert_pipeline_scheduler()
}
}
/// @brief Helper structures for organizing trait data with domain-specific naming
/// @brief Data tile dimensions processed by a workgroup.
/// @details This struct defines the M, N, and K dimensions of the data tile
/// that a single workgroup (thread block) is responsible for processing in the
/// underlying GEMM computation.
struct DataTileInfo
{
int m; ///< M dimension of the tile processed by the workgroup (MPerBlock).
int n; ///< N dimension of the tile processed by the workgroup (NPerBlock).
int k; ///< K dimension of the tile processed by the workgroup (KPerBlock).
};
/// @brief Dimensions for an input data tile transfer.
/// @details Defines the shape of the input tile (A or B matrix) as it is
/// transferred from global memory to LDS. The tile is conceptually divided
/// into k0 and k1 dimensions.
struct InputTileTransferDimensions
{
int k0; ///< The outer dimension of K, where K = k0 * k1.
int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix.
int k1; ///< The inner dimension of K, often corresponding to the vector load size from global
///< memory.
};
/// @brief Parameters governing the transfer of an input tile.
/// @details This struct holds configuration details for how an input tile is
/// loaded from global memory into LDS, including thread clustering, memory
/// access patterns, and vectorization settings.
struct InputTileTransferParams
{
int k1; ///< The inner K dimension size, often matching the vectorization width.
std::array<int, 3>
thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how
///< many threads are arranged on each axis.
std::array<int, 3> thread_cluster_order; ///< The order of thread spatial distribution over the
///< input tensor dimensions.
std::array<int, 3> src_access_order; ///< The order of accessing input tensor axes (e.g., which
///< dimension to read first).
int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed
///< (the contiguous dimension).
int src_scalar_per_vector; ///< The size of the vector access instruction; the number of
///< elements accessed per thread per instruction.
int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1
///< dimension.
bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank
///< conflicts.
};
/// @brief Complete information for an input tile transfer.
/// @details Combines the dimensional information and transfer parameters for
/// a full description of an input tile's journey from global memory to LDS.
struct InputTileTransferInfo
{
InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile.
InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation.
};
/// @brief Parameters for the warp-level GEMM computation.
/// @details Defines the configuration of the GEMM operation performed by each
/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions.
struct WarpGemmParams
{
int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl).
int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl).
int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per
///< wavefront (MXdlPerWave).
int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per
///< wavefront (NXdlPerWave).
};
/// @brief Parameters for shuffling data between warps (CShuffle optimization).
/// @details Configures how many MFMA instruction results are processed per
/// wave in each iteration of the CShuffle routine.
struct WarpShuffleParams
{
int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave
///< per shuffle iteration.
int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave
///< per shuffle iteration.
};
/// @brief Information for the output tile transfer (CShuffle).
/// @details Describes how the final computed tile (C matrix) is written out from
/// LDS to global memory, including shuffling, thread clustering, and vectorization.
struct OutputTileTransferInfo
{
WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling.
// m_block, m_wave_per_xdl, n_block, n_wave_per_xdl
std::array<int, 4> thread_cluster_dims; ///< The spatial thread distribution used for storing
///< data into the output tensor.
int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the
///< output tensor.
};
// Helper metafunctions to derive signature information from Instance types
/// @brief Helper function to report unsupported convolution direction with a clear error message.
template <typename Instance>
consteval void report_unsupported_conv_direction_error()
{
throw "Unsupported convolution direction detected!\n"
"The kernel instance does not have a recognized convolution specialization.\n"
"Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or "
"kConvBwdWeightSpecialization.\n"
"Please verify that your kernel instance is properly configured.";
}
/// @brief Derives the convolution direction from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
@@ -273,7 +189,10 @@ constexpr builder::ConvDirection conv_direction()
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
return builder::ConvDirection::BACKWARD_WEIGHT;
else
return builder::ConvDirection::FORWARD; // Default fallback
{
report_unsupported_conv_direction_error<Instance>();
return builder::ConvDirection::FORWARD; // Unreachable
}
}
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
@@ -296,6 +215,7 @@ constexpr auto conv_spec()
case Filter1x1Pad0: return FILTER_1X1_PAD0;
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
case Filter3x3: return FILTER_3x3;
case OddC: return ODD_C;
}
}
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
@@ -334,6 +254,20 @@ template <typename A,
inline constexpr bool layouts_are =
std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
/// @brief Helper function to report unsupported layout combinations with a clear error message.
/// @details This consteval function is designed to fail at compile time with a descriptive
/// error message when an unsupported layout combination is encountered.
template <typename A, typename B, typename E, int SpatialDim>
consteval void report_unsupported_layout_error()
{
// This will produce a compile-time error with the exception message
throw "Unsupported convolution layout combination detected!\n"
"The combination of ALayout, BLayout, and ELayout template parameters\n"
"is not recognized for the given spatial dimension.\n"
"Please verify that your convolution instance uses a supported layout configuration.\n"
"Check the conv_layout() function for the list of supported layout combinations.";
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array corresponding to the tensor layouts:
@@ -358,6 +292,8 @@ constexpr auto conv_layout()
case 1:
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
return layouts(GNWC, GKXC, GNWK);
if constexpr(layouts_are<A, B, E, ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>)
return layouts(GNWC, GKXC, GNWK);
if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
return layouts(NWGC, GKXC, NWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
@@ -368,8 +304,12 @@ constexpr auto conv_layout()
case 2:
if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
return layouts(GNHWC, GKYXC, GNHWK);
if constexpr(layouts_are<A, B, E, ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>)
return layouts(GNHWC, GKYXC, GNHWK);
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
return layouts(NHWGC, GKYXC, NHWGK);
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>)
return layouts(NHWGC, GKYXC, NHWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
return layouts(NGCHW, GKYXC, NGKHW);
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
@@ -378,6 +318,8 @@ constexpr auto conv_layout()
case 3:
if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
return layouts(GNDHWC, GKZYXC, GNDHWK);
if constexpr(layouts_are<A, B, E, ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>)
return layouts(GNDHWC, GKZYXC, GNDHWK);
if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
return layouts(NDHWGC, GKZYXC, NDHWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
@@ -386,11 +328,31 @@ constexpr auto conv_layout()
return layouts(NGCDHW, GKCZYX, NGKDHW);
break;
}
// If we reach here, the layout combination is not supported
// Call consteval function to trigger a compile-time error with a clear message
report_unsupported_layout_error<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
// This return is unreachable but needed to satisfy the compiler
return layouts(GNHWC, GKYXC, GNHWK);
}
/// @brief Helper function to report unsupported data type with a clear error message.
template <typename ADataType>
consteval void report_unsupported_data_type_error()
{
throw "Unsupported data type detected!\n"
"The ADataType is not recognized.\n"
"Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
"ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
"ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
"(BF8), "
"int8_t (I8), ck::Tuple<int8_t, int8_t> (I8_I8), uint8_t (U8).\n"
"Please verify that your kernel instance uses a supported data type.";
}
/// @brief Derives the data type from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8).
template <typename Instance>
constexpr builder::DataType conv_data_type()
requires HasDataTypes<InstanceTraits<Instance>>
@@ -401,18 +363,50 @@ constexpr builder::DataType conv_data_type()
if constexpr(std::is_same_v<ADataType, ck::half_t>)
return FP16;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
return FP16_FP16;
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
return BF16;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
return BF16_BF16;
else if constexpr(std::is_same_v<ADataType, float>)
return FP32;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
return FP32_FP32;
else if constexpr(std::is_same_v<ADataType, double>)
return FP64;
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
return FP8;
else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t>)
return BF8;
else if constexpr(std::is_same_v<ADataType, ck::bf8_ocp_t>)
return BF8;
else if constexpr(std::is_same_v<ADataType, int8_t>)
return I8;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
return I8_I8;
else if constexpr(std::is_same_v<ADataType, uint8_t>)
return U8;
else
return FP32; // Default fallback
{
report_unsupported_data_type_error<ADataType>();
return FP32; // Unreachable
}
}
/// @brief Helper function to report unsupported elementwise operation with a clear error message.
template <typename ElementwiseOp>
consteval void report_unsupported_elementwise_op_error()
{
throw "Unsupported elementwise operation detected!\n"
"The elementwise operation type is not recognized.\n"
"Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, "
"BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, "
"ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, "
"UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, "
"Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, "
"UnaryConvert.\n"
"Please verify that your kernel instance uses a supported elementwise operation.";
}
/// @brief Derives the elementwise operation from op type.
@@ -424,16 +418,83 @@ constexpr builder::ElementwiseOperation elementwise_op()
using enum builder::ElementwiseOperation;
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
if constexpr(detail::case_insensitive_equal(name, "AddClamp"))
return ADD_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd"))
return ADD_RELU_ADD;
else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
return BIAS_BNORM_CLAMP;
if constexpr(detail::case_insensitive_equal(name, "Clamp"))
else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
return BILINEAR;
else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp"))
return BIAS_BNORM_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
return CLAMP;
if constexpr(detail::case_insensitive_equal(name, "Scale"))
else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale"))
return CONV_INVSCALE;
else if constexpr(detail::case_insensitive_equal(name, "ConvScale"))
return CONV_SCALE;
else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd"))
return CONV_SCALE_ADD;
else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu"))
return CONV_SCALE_RELU;
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
return SCALE;
if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd"))
return SCALE_ADD;
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
return PASS_THROUGH;
if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
return SCALEADD_SCALEADD_RELU;
else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp"))
return DYNAMIC_UNARY_OP;
else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp"))
return UNARY_COMBINED_OP;
else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp"))
return ACTIVATION_MUL2_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp"))
return ACTIVATION_MUL_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp"))
return ADD_ACTIVATION_MUL_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp"))
return ADD_ACTIVATION_MUL2_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp"))
return ADD_MUL_ACTIVATION_MUL_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp"))
return ADD_MUL2_ACTIVATION_MUL_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert"))
return UNARY_CONVERT;
else if constexpr(detail::case_insensitive_equal(name, "Logistic"))
return LOGISTIC;
else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu"))
return CLIPPED_RELU;
else if constexpr(detail::case_insensitive_equal(name, "Swish"))
return SWISH;
else if constexpr(detail::case_insensitive_equal(name, "Elu"))
return ELU;
else if constexpr(detail::case_insensitive_equal(name, "Power"))
return POWER;
else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu"))
return LEAKY_RELU;
else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs"))
return UNARY_ABS;
else if constexpr(detail::case_insensitive_equal(name, "Relu"))
return RELU;
else if constexpr(detail::case_insensitive_equal(name, "SoftRelu"))
return SOFT_RELU;
else if constexpr(detail::case_insensitive_equal(name, "Sigmoid"))
return SIGMOID;
else if constexpr(detail::case_insensitive_equal(name, "TanH"))
return TANH;
else if constexpr(detail::case_insensitive_equal(name, "Gelu"))
return GELU;
else if constexpr(detail::case_insensitive_equal(name, "Silu"))
return SILU;
else
{
report_unsupported_elementwise_op_error<ElementwiseOp>();
return PASS_THROUGH; // Unreachable
}
}
/// @brief Derives a gemm padding from a kernel instance type.
@@ -606,45 +667,4 @@ struct ConvTraits<Instance>
static constexpr auto pipeline_scheduler = get_pipeline_scheduler();
};
/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type.
/// @details This specialization provides backward compatibility for reflecting
/// on kernels defined via the `ConvBuilder` interface. It works by first
/// creating the `Instance` via the builder, and then delegating
/// all trait extraction to the `ConvTraits<Instance>` specialization.
template <builder::ConvSignatureDescriptor auto SIGNATURE,
builder::ConvAlgorithmDescriptor auto ALGORITHM,
builder::StringLiteral VERSION>
struct ConvTraits<builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>>
{
using Instance = typename builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>::Instance;
// Delegate to Instance-based ConvTraits
using InstanceConvTraits = ConvTraits<Instance>;
// Forward all members from Instance-based traits
static constexpr int spatial_dim = InstanceConvTraits::spatial_dim;
static constexpr builder::ConvDirection direction = InstanceConvTraits::direction;
static constexpr auto layout = InstanceConvTraits::layout;
static constexpr builder::DataType data_type = InstanceConvTraits::data_type;
static constexpr builder::ElementwiseOperation input_element_op =
InstanceConvTraits::input_element_op;
static constexpr builder::ElementwiseOperation weight_element_op =
InstanceConvTraits::weight_element_op;
static constexpr builder::ElementwiseOperation output_element_op =
InstanceConvTraits::output_element_op;
static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding;
static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization;
static constexpr int thread_block_size = InstanceConvTraits::thread_block_size;
static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims;
static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer;
static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer;
static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm;
static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer;
static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version;
static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler;
};
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,109 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/// @file
/// @brief Type definitions for convolution reflection
///
/// This file contains the type definitions used by both conv_traits.hpp and conv_description.hpp
/// to avoid circular dependencies.
#pragma once
#include <array>
namespace ck_tile::reflect::conv {
/// @brief Data tile dimensions processed by a workgroup.
/// @details This struct defines the M, N, and K dimensions of the data tile
/// that a single workgroup (thread block) is responsible for processing in the
/// underlying GEMM computation.
struct DataTileInfo
{
int m; ///< M dimension of the tile processed by the workgroup (MPerBlock).
int n; ///< N dimension of the tile processed by the workgroup (NPerBlock).
int k; ///< K dimension of the tile processed by the workgroup (KPerBlock).
};
/// @brief Dimensions for an input data tile transfer.
/// @details Defines the shape of the input tile (A or B matrix) as it is
/// transferred from global memory to LDS. The tile is conceptually divided
/// into k0 and k1 dimensions.
struct InputTileTransferDimensions
{
int k0; ///< The outer dimension of K, where K = k0 * k1.
int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix.
int k1; ///< The inner dimension of K, often corresponding to the vector load size from global
///< memory.
};
/// @brief Parameters governing the transfer of an input tile.
/// @details This struct holds configuration details for how an input tile is
/// loaded from global memory into LDS, including thread clustering, memory
/// access patterns, and vectorization settings.
struct InputTileTransferParams
{
int k1; ///< The inner K dimension size, often matching the vectorization width.
std::array<int, 3>
thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how
///< many threads are arranged on each axis.
std::array<int, 3> thread_cluster_order; ///< The order of thread spatial distribution over the
///< input tensor dimensions.
std::array<int, 3> src_access_order; ///< The order of accessing input tensor axes (e.g., which
///< dimension to read first).
int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed
///< (the contiguous dimension).
int src_scalar_per_vector; ///< The size of the vector access instruction; the number of
///< elements accessed per thread per instruction.
int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1
///< dimension.
bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank
///< conflicts.
};
/// @brief Complete information for an input tile transfer.
/// @details Combines the dimensional information and transfer parameters for
/// a full description of an input tile's journey from global memory to LDS.
struct InputTileTransferInfo
{
InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile.
InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation.
};
/// @brief Parameters for the warp-level GEMM computation.
/// @details Defines the configuration of the GEMM operation performed by each
/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions.
struct WarpGemmParams
{
int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl).
int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl).
int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per
///< wavefront (MXdlPerWave).
int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per
///< wavefront (NXdlPerWave).
};
/// @brief Parameters for shuffling data between warps (CShuffle optimization).
/// @details Configures how many MFMA instruction results are processed per
/// wave in each iteration of the CShuffle routine.
struct WarpShuffleParams
{
int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave
///< per shuffle iteration.
int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave
///< per shuffle iteration.
};
/// @brief Information for the output tile transfer (CShuffle).
/// @details Describes how the final computed tile (C matrix) is written out from
/// LDS to global memory, including shuffling, thread clustering, and vectorization.
struct OutputTileTransferInfo
{
WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling.
// m_block, m_wave_per_xdl, n_block, n_wave_per_xdl
std::array<int, 4> thread_cluster_dims; ///< The spatial thread distribution used for storing
///< data into the output tensor.
int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the
///< output tensor.
};
} // namespace ck_tile::reflect::conv

View File

@@ -20,6 +20,11 @@ namespace ck_tile::reflect {
class Description
{
public:
Description() = default;
Description(const Description&) = default;
Description(Description&&) = default;
Description& operator=(const Description&) = default;
Description& operator=(Description&&) = default;
/// @brief Virtual destructor for proper cleanup of derived classes
virtual ~Description() = default;
@@ -36,4 +41,30 @@ class Description
virtual std::string instance_string() const = 0;
};
/// @brief A specialized Description that only supports instance_string()
/// This is a helper class for kernels that don't yet have full ConvDescription support.
/// The brief() and detailed() methods return "not supported" placeholders.
class InstanceStringDescription : public Description
{
public:
/// @brief Construct with an instance string
/// @param instance The instance string to store
explicit InstanceStringDescription(std::string instance) : instance_(std::move(instance)) {}
/// @brief Returns "not supported" as brief descriptions are not implemented
/// @return A placeholder string indicating the feature is not supported
std::string brief() const override { return "not supported"; }
/// @brief Returns "not supported" as detailed descriptions are not implemented
/// @return A placeholder string indicating the feature is not supported
std::string detailed() const override { return "not supported"; }
/// @brief Returns the stored instance string
/// @return The instance string provided during construction
std::string instance_string() const override { return instance_; }
private:
std::string instance_; ///< The stored instance string
};
} // namespace ck_tile::reflect

View File

@@ -0,0 +1,347 @@
# CK-Builder Testing Utilities
This directory contains testing utilities designed to simplify the process of writing unit tests for GPU kernels built with `ck_tile::builder`. These utilities enable a clean, expressive **Given-When-Then** (Given-When-Then) testing pattern that separates test setup, execution, and validation.
See the [main builder documentation](../README.md) for an overview of the CK-Builder API components.
## Overview
Testing GPU kernels typically involves significant boilerplate: allocating device memory, initializing test data, launching kernels, and validating results. The utilities in this directory abstract away these repetitive tasks, allowing you to focus on defining test cases and verifying correctness.
The core components are:
- **`Args`**: A struct template that holds runtime parameters for a specific test case.
- **`Input`** and **`Output`**: Helper classes that groups operation inputs and outputs.
- **`Validator`**: A utility that performs on-GPU validation and integrates with GoogleTest/GoogleMock.
Together, these components enable a structured approach to kernel testing that mirrors the Given-When-Then pattern commonly used in behavior-driven development.
## The Given-When-Then Testing Pattern
The Given-When-Then pattern organizes tests into three distinct phases:
1. **Given**: Set up the preconditions and test data
2. **When**: Execute the action being tested
3. **Then**: Verify the expected outcome
This structure makes tests easier to read, write, and maintain. Each phase has a clear purpose, and the testing utilities are designed to support this workflow.
### Given: Defining the Test Case
The "Given" phase establishes the context for your test. This includes both the compile-time characteristics of the kernel and the runtime parameters for the specific test case.
#### Operation Signature
The "signature" defines the **mathematical contract** that the kernel must satisfy. It specifies compile-time properties such as:
- Spatial dimensionality (1D, 2D, or 3D)
- Convolution direction (Forward, Backward Data, Backward Weight)
- Tensor memory layout (e.g., NHWC, NCHW)
- Data types (FP32, FP16, BF16, etc.)
- Fused element-wise operations (e.g., Bias, ReLU)
The format of the signature struct is enforced at compile time using C++20 concepts by the CK-Builder API, ensuring type safety and enabling compile-time optimizations. The design of these concepts and the required constraints are discussed in the [CK Builder design description](../include/ck_tile/builder/README.md).
```cpp
// Define our custom signature struct.
struct ConvSignature {
int spatial_dim = 2;
ck_tile::builder::ConvDirection direction =
ck_tile::builder::ConvDirection::FORWARD;
ck_tile::builder::GroupConvLayout2D layout =
ck_tile::builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK;
ck_tile::builder::DataType data_type =
ck_tile::builder::DataType::FP16;
ck_tile::builder::ElementwiseOperation elementwise_operation =
ck_tile::builder::ElementwiseOperation::NONE;
};
// Double-check that out structure is well-defined according to the CK-Builder API.
static_assert(ck_tile::builder::ConvSignatureDescriptor<ConvSignature>);
// Instantiate the signature with a configuration. These values are again checked
// by the CK-Builder API when a device operation is built.
constexpr auto SIGNATURE = ConvSignature{
.spatial_dim = 2,
.direction = ck_tile::builder::ConvDirection::FORWARD,
.layout = ck_tile::builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = ck_tile::builder::DataType::FP16,
.elementwise_operation = ck_tile::builder::ElementwiseOperation::NONE,
};
```
#### Run-time Arguments
The `Args` struct template provides the **runtime parameters** for your test case. It is parameterized by the `SIGNATURE` and contains fields for tensor dimensions, strides, dilations, and other dynamic properties. Note that the exact parameters required for each `Args` depends on the `SIGNATURE`: For example, a `SIGNATURE` that represents a forward convolution requires specifying the number of batches, groups, input- and output-channels, filter dimensions, filter strides, and so on. A `SIGNATURE` that represents a simple GEMM operation may instead require only the dimensions of the A-, B- and C-matrices.
```cpp
ck_tile::builder::test::Args<SIGNATURE> args = {
.lengths = {
.batch_size = 128,
.groups = 1,
.input_channels = 64,
.output_channels = 128,
.image = {.height = 56, .width = 56},
.filter = {.height = 3, .width = 3},
},
.filter_strides = {.height = 1, .width = 1},
.filter_dilation = {.height = 1, .width = 1},
.input_left_pad = {.width = 1, .height = 1},
.input_right_pad = {.width = 1, .height = 1},
};
```
#### Tensor Memory Management
Tensor memory is passed using the `Inputs<SIGNATURE>` and `Outputs<SIGNATURE>` structures. These group all inputs and outputs for an operation. Note that these structures do not "own" the memory inside: They only logically group the inputs so that they can be passed as a common type. The amount of inputs and outputs may differ depending on the `SIGNATURE`, and this avoids having to pass additional values and accept additional parameters in those situations.
The exact fields in `Inputs` and `Outputs` depend again on the particular `SIGNATURE` that they are constructed with. In general, these structures are intended to be freely constructible from external data and only serve to group relevant information. Automatic memory management can be performed using the `UniqueInputs<SIGNATURE>` and `UniqueOutputs<SIGNATURE>` structures instead. The `alloc_inputs` and `alloc_outputs` functions are used to initialize these types: They take an `Args` structure and allocate the appropriate amounts of memory. `.get()` is used to return an instance of the appropriate `Input` or `Output`.
```cpp
auto inputs = ck_tile::builder::test::allocate_inputs(args);
auto outputs = ck_tile::builder::test::allocate_outputs(args);
```
Note that these functions merely _allocate_ memory: After allocation, the memory is still uninitialized.
#### Tensor Memory Initialization
Operation inputs can be initialized by using `ck_tile::builder::test::init_inputs()`. Crucially, this operation accepts _all_ inputs, as well as the `args` structure. This is because initializing tensor memory is a context-dependent operation: We need to understand the operation in detail in order to generate inputs which do not overflow, do not generate NaNs or all zeros, etc. Passing the `args` allows `init_inputs` to generate a good test for the operation at hand.
### When: Executing the Kernel
The "When" phase is where the kernel to be tested is actually executed. This involves selecting an algorithm and using the `Builder` to generate the kernel.
#### Operation Algorithm
The "algorithm" defines the **implementation strategy** for the kernel. It specifies low-level details such as:
- Thread block dimensions and tile sizes
- GEMM implementation (XDL or WMMA)
- Data transfer vectorization
- Pipeline scheduling
As with the signature struct, the format of the algorithm struct is enforced at compile time using C++20 concepts by the CK-Builder API. The design of these concepts and the required constraints are discussed in the [CK Builder factory design description](../include/ck_tile/builder/factory/README.md).
```cpp
// Define our custom algorithm struct.
struct ConvAlgorithm {
// Thread block configuration
ThreadBlock thread_block;
// Gridwise GEMM configuration
GridwiseXdlGemm gridwise_gemm;
// Block transfer configuration
Transfer transfer;
// Additional tuning parameters
// ...
};
// Double-check that our algorithm is well-defined according to the CK-Builder API.
static_assert(ck_tile::builder::ConvAlgorithmDescriptor<ConvAlgorithm>);
// Instantiate the algorithm with a configuration. Like with the signature struct
// the CK-Builder API will check that the values are correct when a device
// operation is built.
constexpr auto ALGORITHM = ConvAlgorithm{
.thread_block = /* ... */;
.gridwise_gem = /* ... */;
.transfer = /* ... */;
// ...
};
```
#### Building the Kernel
The `Builder` combines the signature (what to compute) with the algorithm (how to compute it) to generate a kernel type which represents the operation. The implementation details, including invocation method, depend on the particular signature and algorithm.
```cpp
using Conv = ck_tile::builder::ConvBuilder<SIGNATURE, ALGORITHM>::Instance;
auto conv = Conv{};
```
#### Invoking the Kernel
After creating the kernel instance, it can be invoked by passing the instance, the arguments, the inputs, and the outputs to `run()`. This operation writes results into the buffers in `outputs`.
```cpp
ck_tile::builder::test::run(conv, args, inputs.get(), outputs.get());
```
### Then: Verifying the Results
The "Then" phase validates that the kernel produced the expected output. This is done by running a reference kernel and comparing the results.
#### Building the Reference Kernel
The reference kernel is just another kernel instance of the builder, one that's been externally verified to produce the correct results. As this kernel is also running on the GPU, we can use it to perform tests far more quickly than when comparing the outputs to a CPU-based reference implementation.
In order to obtain an instance of the reference kernel, the correct `ALGORITHM` needs to be passed to the `Builder`.
```cpp
struct ReferenceAlgorithm {
ck_tile::builder::ConvAlgorithmSpecialization specialization;
};
static_assert(ck_tile::builder::ConvAlgorithmDescriptor<ReferenceAlgorithm>);
constexpr auto REFERENCE_ALGORITHM = ReferenceAlgorithm{
.specialization = ck_tile::builder::ConvAlgorithmSpecialization::REFERENCE;
};
using ReferenceConv = ck_tile::builder::ConvBuilder<SIGNATURE, REFERENCE_ALGORITHM>::Instance;
auto reference_conv = ReferenceConv{};
```
This instance can then be invoked using `ck_tile::builder::test::run()`, the same as the kernel to be tested. Note that another instance of the `Outputs` structure needs to be passed here in order to store the results.
```cpp
auto reference_outputs = ck_tile::builder::test::allocate_outputs(args);
ck_tile::builder::test::run(conv, args, inputs.get(), reference_outputs.get());
```
#### `Validator<SIGNATURE>`
The `Validator` class encapsulates the validation logic. It performs on-GPU correctness checks by comparing two instances of the `Outputs` structure.
```cpp
ck_tile::builder::test::Validator<SIGNATURE> validator(outputs.get(), reference_outputs.get());
```
The `Validator` provides methods that return GoogleMock matchers, enabling clean integration with GoogleTest:
```cpp
EXPECT_THAT(validator.result(), validator.matches_reference_output());
```
The `matches_reference_output()` matcher checks that the output is numerically correct within acceptable tolerances. The `Validator` can also provide more detailed diagnostics, such as:
- Maximum absolute error
- Maximum relative error
- Number of mismatched elements
- Specific locations of errors
## Complete Example
Here's a complete test that demonstrates the Given-When-Then pattern:
```cpp
#include <gtest/gtest.h>
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_builder.hpp"
#include "ck_tile/testing/tensor_memory_manager.hpp"
#include "ck_tile/testing/validator.hpp"
// Define the convolution signature
struct ConvSignature {
int spatial_dim = 2;
ck_tile::builder::ConvDirection direction =
ck_tile::builder::ConvDirection::FORWARD;
ck_tile::builder::GroupConvLayout2D layout =
ck_tile::builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK;
ck_tile::builder::DataType data_type =
ck_tile::builder::DataType::FP16;
ck_tile::builder::ElementwiseOperation elementwise_operation =
ck_tile::builder::ElementwiseOperation::NONE;
};
static_assert(ck_tile::builder::ConvSignatureDescriptor<ConvSignature>);
constexpr auto SIGNATURE = ConvSignature{
.spatial_dim = 2,
.direction = ck_tile::builder::ConvDirection::FORWARD,
.layout = ck_tile::builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = ck_tile::builder::DataType::FP16,
.elementwise_operation = ck_tile::builder::ElementwiseOperation::NONE,
};
// Define the convolution algorithm
struct ConvAlgorithm {
// Algorithm configuration details...
// (Omitted for brevity)
};
static_assert(ck_tile::builder::ConvAlgorithmDescriptor<ConvAlgorithm>);
constexpr auto ALGORITHM = ConvAlgorithm{/* ... */};
// Define the reference convolution algorithm
struct ReferenceAlgorithm {
ck_tile::builder::ConvAlgorithmSpecialization specialization;
};
static_assert(ck_tile::builder::ConvAlgorithmDescriptor<ReferenceAlgorithm>);
constexpr auto REFERENCE_ALGORITHM = ReferenceAlgorithm{
.specialization = ck_tile::builder::ConvAlgorithmSpecialization::REFERENCE;
};
// The actual test
TEST(ConvolutionTest, Forward2D_FP16) {
// ===== GIVEN: Set up the test case =====
// Define runtime parameters
ck_tile::builder::test::Args<ConvSignature> args = {
.lengths = {
.batch_size = 128,
.groups = 1,
.input_channels = 64,
.output_channels = 128,
.image = {.height = 56, .width = 56},
.filter = {.height = 3, .width = 3},
},
.filter_strides = {.height = 1, .width = 1},
.filter_dilation = {.height = 1, .width = 1},
.input_left_pad = {.width = 1, .height = 1},
.input_right_pad = {.width = 1, .height = 1},
};
// Allocate GPU memory
auto inputs = ck_tile::builder::test::allocate_inputs(args);
auto outputs = ck_tile::builder::test::allocate_outputs(args);
auto reference_outputs = ck_tile::builder::test::allocate_outputs(args);
// Initialize inputs
ck_tile::builder::test::init_inputs(args, inputs);
// ===== WHEN: Execute the kernel =====
// Build the kernel
using Conv = ck_tile::builder::ConvBuilder<SIGNATURE, ALGORITHM>::Instance;
auto conv = Conv{};
// Compute actual results
ck_tile::builder::test::run(conv, args, inputs.get(), outputs.get());
// ===== THEN: Verify the results =====
// Build the reference kernel
using ReferenceConv = ck_tile::builder::ConvBuilder<SIGNATURE, REFERENCE_ALGORITHM>::Instance;
auto reference_conv = ReferenceConv{};
// Compute reference results
ck_tile::builder::test::run(conv, args, inputs.get(), reference_outputs.get());
// Check the results
ck_tile::builder::test::Validator<SIGNATURE> validator(outputs.get(), reference_outputs.get());
EXPECT_THAT(validator.result(), validator.is_ok());
}
```
## Benefits of This Approach
1. **Clarity**: The Given-When-Then structure makes tests self-documenting. Each phase has a clear purpose.
2. **Reduced Boilerplate**: The utilities handle memory management, initialization, and validation, eliminating repetitive code.
3. **Type Safety**: The use of C++20 concepts ensures that signatures and algorithms are well-formed at compile time.
4. **Flexibility**: The `Args` struct can be easily extended to support different test scenarios, `Inputs` and `Outputs` can be modified to support additional tensors where necessary, and alternatives to `init_inputs()` can be provided to support additional testing strategies.
5. **Integration**: The `Validator` integrates seamlessly with GoogleTest/GoogleMock, providing familiar assertion syntax.
6. **Maintainability**: Changes to the testing infrastructure are localized to the utility classes, not scattered across individual tests.
## Future Enhancements
Potential improvements to the testing utilities include:
- Performance benchmarking utilities
- Automatic test case generation from parameter ranges
- Enhanced error reporting with visual diffs
- Support for multi-GPU testing scenarios

View File

@@ -0,0 +1,256 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/testing/testing.hpp"
#include "ck_tile/builder/testing/extent.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
/// This file implements common functionality for invoking/testing grouped
/// forward convolutions created through the CK Builder API. The main item
/// of it is the ConvArgs structure - which contains a complete description
/// of a convolution operation.
///
/// It is not intended that this file contains implementation details for
/// actually launching a convolution operation. As this can be done
/// through different APIs depending on the kernel (CK, CK Tile, or a
/// reference implementation), the code dealing with that is split out
/// into a separate header for each implementation.
namespace ck_tile::builder::test {
/// @brief Convolution tensor dimensions.
///
/// This structure is used to describe lengths of a convolution problem. In
/// fact, this structure is a complete description of ALL inputs and outputs
/// lengths of a convolution problem, as this structure contains all of the
/// combined parameters. Note that we can't also use this structure to describe
/// tensor strides: whereas the lengths are all governed by a common set of
/// parameters, strides of the input, weight, and output tensor are all
/// independent.
template <int SPATIAL_DIM>
struct ConvTensorLengths
{
size_t batch_size = 1; // N
size_t groups = 1; // G
size_t input_channels = 1; // C
size_t output_channels = 1; // K
Extent<SPATIAL_DIM> image = {}; // W, H, D
Extent<SPATIAL_DIM> filter = {}; // X, Y, Z
};
/// @brief `Args` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see Args
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct Args<SIGNATURE>
{
constexpr static auto SPATIAL_DIM = SIGNATURE.spatial_dim;
constexpr static auto INPUT_TYPE = SIGNATURE.data_type;
constexpr static auto WEIGHT_TYPE = SIGNATURE.data_type;
constexpr static auto OUTPUT_TYPE = SIGNATURE.data_type;
// TODO: We shouldn't need to call into an internal namespace here.
using Ops = factory::internal::ElementwiseOps<SIGNATURE>;
// TODO: We shouldn't need to call into an internal namespace here.
using Layouts =
factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
ConvTensorLengths<SPATIAL_DIM> lengths;
// TODO: Tensor strides. This needs a new structure as well as some
// reworking of the make_*_descriptor() functions, as the current
// implementation (based on ConvParam in old CK / CK Tile) does not
// support strides at all.
Extent<SPATIAL_DIM> filter_strides;
Extent<SPATIAL_DIM> filter_dilation;
Extent<SPATIAL_DIM> input_left_pad;
Extent<SPATIAL_DIM> input_right_pad;
Ops::AElementwiseOp a_elementwise_op;
Ops::BElementwiseOp b_elementwise_op;
Ops::CDEElementwiseOp cde_elementwise_op;
/// This function returns the `TensorDescriptor` corresponding to
/// the input-tensor of the convolution problem. This can then
/// be used to, for example, allocate memory.
TensorDescriptor<INPUT_TYPE> make_input_descriptor() const
{
// TODO: We're using old CK functionality to compute the right
// values here, mainly because CK tile does not support the
// right tensor layouts here. We should probably change that
// because CK currently prints an annoying message about it,
// plus that would let us get rid of the `to_ck_conv_param()`
// function.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
typename Layouts::ALayout>(param);
return TensorDescriptor<INPUT_TYPE>(desc.GetLengths(), desc.GetStrides());
}
/// This function returns the `TensorDescriptor` corresponding to
/// the weight-tensor of the convolution problem. This can then
/// be used to, for example, allocate memory.
TensorDescriptor<WEIGHT_TYPE> make_weight_descriptor() const
{
// See note in implementation of `make_input_descriptor`.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<
typename Layouts::BLayout>(param);
return TensorDescriptor<WEIGHT_TYPE>(desc.GetLengths(), desc.GetStrides());
}
/// This function returns the `TensorDescriptor` corresponding to
/// the output-tensor of the convolution problem. This can then
/// be used to, for example, allocate memory.
TensorDescriptor<OUTPUT_TYPE> make_output_descriptor() const
{
// See note in implementation of `make_input_descriptor`.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<
typename Layouts::ELayout>(param);
return TensorDescriptor<OUTPUT_TYPE>(desc.GetLengths(), desc.GetStrides());
}
/// Convert the Args structure into a CK conv_param structure. This
/// function is mainly used to be able to use the existing
/// CK-functionality to obtain tensor descriptors.
ck::utils::conv::ConvParam to_ck_conv_param() const
{
const auto to_vector = [](const auto& extent) {
if constexpr(SPATIAL_DIM == 1)
return std::vector<ck::index_t>{ck::index_t(extent.width)};
else if constexpr(SPATIAL_DIM == 2)
return std::vector<ck::index_t>{ck::index_t(extent.height),
ck::index_t(extent.width)};
else
return std::vector<ck::index_t>{ck::index_t(extent.depth),
ck::index_t(extent.height),
ck::index_t(extent.width)};
};
return ck::utils::conv::ConvParam(SPATIAL_DIM,
this->lengths.groups,
this->lengths.batch_size,
this->lengths.output_channels,
this->lengths.input_channels,
to_vector(this->lengths.filter),
to_vector(this->lengths.image),
to_vector(this->filter_strides),
to_vector(this->filter_dilation),
to_vector(this->input_left_pad),
to_vector(this->input_right_pad));
}
};
/// @brief `Inputs` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see Inputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct Inputs<SIGNATURE>
{
void* input;
void* weight;
};
/// @brief `Outputs` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see Outputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct Outputs<SIGNATURE>
{
void* output;
};
/// @brief `UniqueInputs` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see UniqueInputs
/// @see ValidUniqueInputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct UniqueInputs<SIGNATURE>
{
DeviceBuffer input_buf;
DeviceBuffer weight_buf;
/// @see ValidUniqueInputs
Inputs<SIGNATURE> get()
{
return {
.input = input_buf.get(),
.weight = weight_buf.get(),
};
}
};
/// @brief `UniqueOutputs` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see UniqueOutputs
/// @see ValidUniqueOutputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct UniqueOutputs<SIGNATURE>
{
DeviceBuffer output_buf;
/// @see ValidUniqueOutputs
Outputs<SIGNATURE> get()
{
return {
.output = output_buf.get(),
};
}
};
/// @brief `alloc_inputs()` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see alloc_inputs()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
ValidUniqueInputs<SIGNATURE>
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args)
{
return {
.input_buf = alloc_tensor_buffer(args.make_input_descriptor()),
.weight_buf = alloc_tensor_buffer(args.make_weight_descriptor()),
};
}
/// @brief `alloc_outputs()` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see alloc_outputs()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
ValidUniqueOutputs<SIGNATURE>
UniqueOutputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args)
{
return {
.output_buf = alloc_tensor_buffer(args.make_output_descriptor()),
};
}
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,102 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <span>
#include <cstddef>
#include "ck_tile/builder/testing/conv_fwd.hpp"
/// This file contains the implementation details for invoking/testing
/// grouped convolution operations in old CK. The main item is the
/// `run()` function, which is the main implementation used to invoke
/// CK grouped forward convolution kernels.
namespace ck_tile::builder::test {
/// @brief Concept for checking whether a convolution is invoked like old CK.
///
/// This concept is used to tell whether a convolution implementation is
/// likely to be an "old CK" implementation - that is, whether we should
/// invoke it as an old CK kernel. This is mainly used with `run()` to
/// differentiate which implementation that should be invoked.
///
/// - SIGNATURE is the operation signature.
/// - Conv is a convolution instance created by the CK Builder API.
template <auto SIGNATURE, typename Conv>
concept IsCkConvInstance =
// TODO: This should be implemented by converting the signature into the
// type parameters for DeviceGroupedConvFwdMultipleABD. For now, just leave
// it empty. Improve when needed, you get the point. Also we should probably
// move this to the ck conv factory helper.
true;
/// @brief `run()` specialization for forward convolution and old CK.
///
/// @tparam SIGNATURE Forward convolution signature.
/// @throws std::runtime_error if the arguments werent actually valid for the
/// operation. This should be caught and reported by the testing framework.
///
/// @see run()
template <auto SIGNATURE, typename Conv>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
IsCkConvInstance<SIGNATURE, Conv>
void run(Conv& conv,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs)
{
constexpr auto spatial_dim = SIGNATURE.spatial_dim;
const auto copy = [](const auto& src, auto& dst) {
std::copy(src.begin(), src.end(), dst.begin());
};
const auto to_ck_lengths = [&](const auto& src) {
std::array<ck::index_t, spatial_dim + 3> result;
copy(src, result);
return result;
};
const auto to_ck_extent = [&](const auto& extent) {
std::array<ck::index_t, spatial_dim> result;
copy(extent, result);
return result;
};
const auto param = args.to_ck_conv_param();
const auto input_desc = args.make_input_descriptor();
const auto weight_desc = args.make_weight_descriptor();
const auto output_desc = args.make_output_descriptor();
auto ck_args = conv.MakeArgument(inputs.input,
inputs.weight,
{},
outputs.output,
to_ck_lengths(input_desc.get_lengths()),
to_ck_lengths(input_desc.get_strides()),
to_ck_lengths(weight_desc.get_lengths()),
to_ck_lengths(weight_desc.get_strides()),
{},
{},
to_ck_lengths(output_desc.get_lengths()),
to_ck_lengths(output_desc.get_strides()),
to_ck_extent(param.conv_filter_strides_),
to_ck_extent(param.conv_filter_dilations_),
to_ck_extent(param.input_left_pads_),
to_ck_extent(param.input_right_pads_),
args.a_elementwise_op,
args.b_elementwise_op,
args.cde_elementwise_op);
if(!conv.IsSupportedArgument(ck_args))
{
throw std::runtime_error("invalid argument");
}
conv.MakeInvoker().Run(ck_args, {});
}
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,36 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile::builder::test {
/// This structure describes a 1-, 2-, or 3-D extent. Its used to
/// communicate 1-, 2- or 3-D sizes and strides of tensors.
/// Depending on the dimension, the structure will have the `width`,
/// `height`, and `depth` fields available.
template <int SPATIAL_DIM>
struct Extent;
template <>
struct Extent<1>
{
size_t width = 1;
};
template <>
struct Extent<2>
{
size_t width = 1;
size_t height = 1;
};
template <>
struct Extent<3>
{
size_t width = 1;
size_t height = 1;
size_t depth = 1;
};
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,212 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <stdexcept>
#include <memory>
#include <numeric>
#include <span>
#include <concepts>
#include <hip/hip_runtime.h>
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/testing/type_traits.hpp"
#include "ck_tile/host/host_tensor.hpp"
/// This file deals with tensor memory allocation: Both the act of allocating
/// and (automatically) deallocating memory, as well as utilities for managing
/// the layout of tensor data in memory.
namespace ck_tile::builder::test {
/// @brief Automatic deleter for GPU memory.
///
/// This structure implements a C++ functor which can be used to configure
/// `std::unique_ptr` to automatically delete memory using `hipFree`.
///
/// @see DeviceBuffer
struct DeviceMemoryDeleter
{
/// @brief Deleter callback.
///
/// This function is invoked by `std::unique_ptr` when memory that the
/// pointer represents should be freed. In our implementation, we just
/// pass it directly to `hipFree`.
void operator()(std::byte* ptr) const
{
if(ptr)
(void)hipFree(ptr);
}
};
/// @brief HIP out of memory error
///
/// This is a derivation of `std::runtime_error` specialized for HIP
/// out-of-memory errors.
///
/// @see std::runtime_error
struct OutOfDeviceMemoryError : std::runtime_error
{
/// @brief Utility for formatting out-of-memory error messages
///
/// Returns a human-readable description of a HIP out-of-memory error.
///
/// @param status The status to report
static std::string format_error(hipError_t status)
{
return std::string("failed to allocate hip memory: ") + hipGetErrorString(status) + " (" +
std::to_string(status) + ")";
}
/// @brief Construct an out-of-memory error using `status` as message.
///
/// @param status A HIP error status that was encountered while allocating memory.
OutOfDeviceMemoryError(hipError_t status) : std::runtime_error(format_error(status)) {}
};
/// @brief Automatically managed GPU memory.
///
/// The `DeviceBuffer` is an automatically managed pointer for GPU memory. When
/// adopting a device pointer into a `DeviceBuffer`, it will automatically be
/// free'd when the pointer goes out of scope. Memory can be allocated directly
/// into a `DeviceBuffer` using `alloc_buffer()` or `alloc_tensor_buffer()`.
///
/// Since this type is just an alias of `std::unique_ptr`, you can use that type's
/// functionality to manage memory further, such as `.reset()` to release the
/// memory.
///
/// @see alloc_buffer()
/// @see alloc_tensor_buffer()
using DeviceBuffer = std::unique_ptr<std::byte[], DeviceMemoryDeleter>;
/// @brief Allocate automatically managed GPU memory.
///
/// This function essentially acts like a managed version of hipMalloc -
/// allocating GPU memory on the currently active device - except that this
/// version returns an automatically managed pointer.
///
/// @param size The amount of memory to allocate in bytes.
/// @throws OutOfDeviceMemoryError if memory allocation failed.
///
/// @see DeviceBuffer
/// @see OutOfDeviceMemoryError
/// @see hipMalloc()
inline DeviceBuffer alloc_buffer(size_t size)
{
std::byte* d_buf = nullptr;
if(const auto status = hipMalloc(&d_buf, size); status != hipSuccess)
{
throw OutOfDeviceMemoryError(status);
}
return DeviceBuffer(d_buf);
}
/// @brief Type managing tensor data layout in memory.
///
/// This structure describes a tensor in memory. It does not actually hold any
/// reference to memory, it just describes how the memory should be laid out if it
/// were.
///
/// @note This type is very much like ck_tile::HostTensorDescriptor, except that it
/// also includes the data type of the elements of htis tensor. This is mainly to
/// make the descriptor a _complete_ description of a tensor rather than just the
/// dimensions in strides, which helps in reducing clutter in uses of this type.
///
/// @note All strides are still in _elements_.
///
/// @tparam DT The conceptual data type of the tensor elements. This need not be the
/// type that the data is actually stored as in memory.
template <DataType DT>
struct TensorDescriptor
{
// For now, the implementation of this type is based on
// `ck_tile::HostTensorDescriptor`, so that we can prototype without
// reimplementing the `HostTensorDescriptor` for the 3rd time. You can regard
// the use of `ck_tile::HostTensorDescriptor` here as an implementation detail.
/// The conceptual data type of the tensor elements. This need not be the type
/// that the data is actually stored as in memory.
constexpr static DataType data_type = DT;
/// @brief Create a tensor descriptor from lengths and strides.
///
/// @param lengths A sequence of tensor lengths, the conceptial dimensions of
/// the tensor in elements.
/// @param strides A sequence of in-memory strides of the tensor, measured in
/// elements. Each element of `strides`` corresponds to one at the same index
/// in `lengths`, the amount of elements to skip in memory to find the next
/// element along that axis.
TensorDescriptor(std::span<const size_t> lengths, std::span<const size_t> strides)
: inner_descriptor_(lengths, strides)
{
// TODO: Validation of strides? For now we just delegate the details of the
// construction to the CK Tile HostTensorDescriptor.
}
/// Query the conceptual dimensions of the tensor.
///
/// @returns A span of tensor dimensions, one for every axis. Note that the order
/// does *not* correspond with memory layout, query the in-memory strides for
/// that.
///
/// @see get_strides()
std::span<const size_t> get_lengths() const { return inner_descriptor_.get_lengths(); }
/// Query the in-memory strides of the tensor.
///
/// @returns A span of tensor dimensions, one for every axis. Each element
/// corresponds directly with the stride in elements at the same index in the
/// tensor dimensions.
///
/// @see get_lengths()
std::span<const size_t> get_strides() const { return inner_descriptor_.get_strides(); }
/// @brief Compute total tensor size in elements.
///
/// This function returns the total size of the memory backing a tensor with
/// this descriptor in *elements*, including required extra size for strides.
///
/// @see get_element_space_size_in_bytes()
size_t get_element_space_size() const { return inner_descriptor_.get_element_space_size(); }
/// @brief Compute total tensor size in bytes.
///
/// This function is like `get_element_space_size()`, except that the returned
/// value is measured in *bytes* rather than *elements*. Use this function for
/// figuring out how much memory needs to be allocated for a particular tensor.
///
/// @see get_element_space_size()
size_t get_element_space_size_in_bytes() const
{
// For now, the backing type is the naive C++-type that represents the data
// type. When we are going to support packed types such as i4 and fp6, this
// is going to become more complicated.
return get_element_space_size() * data_type_sizeof(DT);
}
private:
ck_tile::HostTensorDescriptor inner_descriptor_;
};
/// @brief Allocate automatically managed GPU memory corresponding to a tensor descriptor.
///
/// This function is similar to `alloc_buffer()`, except that the required size is
/// derived automatically from a tensor descriptor. The returned buffer is valid for
/// tensors with that layout. Strides are also taken into account when computing the
/// required size.
///
/// @tparam DT The conceptual datatype of the elements of the tensor.
/// @param descriptor A descriptor of the memory layout of the tensor to allocate.
/// @throws OutOfDeviceMemoryError if memory allocation failed.
///
/// @see TensorDescriptor
/// @see DeviceBuffer
/// @see OutOfDeviceMemoryError
/// @see hipMalloc()
template <DataType DT>
DeviceBuffer alloc_tensor_buffer(const TensorDescriptor<DT>& descriptor)
{
return alloc_buffer(descriptor.get_element_space_size_in_bytes());
}
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,260 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
/// This file is the main header for the CK-Builder testing system. A high-level
/// description of this testing system is documented in
/// `ck_tile/builder/testing/README.md`. This file deals mainly deals with the
/// documentation of the implementation details by forward-declaring and documenting
/// the relevant types.
///
/// The intention is that the basic testing strategy (explained in the testing
/// documentation) is available for every different type of device operation. This
/// requires us to provide some implementations in two fronts: Support for the
/// Args, Inputs, Outputs, UniqueInputs, and UniqueOutputs for all SIGNATUREs which
/// are supported by CK Builder, and support for invoking the different
/// implementations returned by CK Builder, depending on the Algorithm.
///
/// Different SIGNATUREs may require different arguments and different (amounts of)
/// input/output tensors. Rather than trying to cram all this in the same structure,
/// or to provide different types, we will use dependent typing to specialize the
/// implementation for the SIGNATURE at hand. For this reason, the Args, Inputs,
/// Outputs, UniqueInputs, and UniqueOutputs structures are all parameterized by the
/// SIGNATURE. The idea is to use C++20 concepts to limit the specialization to the
/// subset of SIGNATUREs that conceptually make sense for that implementation. For
/// example, to provide an implementation of the testing framework for forward
/// convolutions, we can use a concept to check whether the SIGNATURE is a valid
/// forward convolution signature:
///
/// template <auto SIGNATURE>
/// requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
/// struct Args<SIGNATURE> { ... }; // Similar for the other types
///
/// Invocation of instances is another matter: The Builder may return instances from
/// either CK or CK-Tile depending on the ALGORITHM configuration. The only place
/// where this matters is the implementation of `run()`, which needs to provide a
/// custom implementation for all instances which the Builder may return, including
/// the reference implementation. The strategy is the same here: Use concepts to
/// check whether the instance returned by the builder is of a particular type, and
/// overload the `run()` function for that concept:
///
/// template <auto SIGNATURE, typename Conv>
/// requires
/// // Check that the SIGNATURE is of the type that we expect
/// ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
/// // Also check that the instance is of a type which we can invoke here
/// IsCkConvInstance<SIGNATURE, Conv>
/// void run(Conv& conv, ...);
///
/// Note that this is only the suggested strategy; you may also use `if constexpr`
/// or similar to dispatch the correct implementation of the instance in the
/// implementation of the `run()` function for a particular group of device
/// operations.
///
/// The remainder of this file describes the types and functions that should be
/// overloaded for a particular device operation, and in which situation.
namespace ck_tile::builder::test {
/// @brief Run-time arguments corresponding to a signature.
///
/// The `Args` structure is the main point of runtime configuration for a device
/// operation. Depending on the SIGNATURE, it is used to provide the run-time
/// parameters for a device operation, for instance, for the tensor dimensions,
/// tensor strides, parameters such as padding, split-K batch size, fused
/// element-wise operator instances, etc. In short, a complete run-time
/// configuration of the tensor operation at hand.
///
/// This structure does not require additional member functions, any which are
/// provided should be considered implementation details of Args structure for
/// that particular SIGNATURE.
///
/// @note A good indicator of the fields necessary here are the values that should
/// be passed to the CK `MakeArgument()` function or CK-Tile `HostArgs` structure
/// of the device operation that you are trying to implement. It is the intention
/// that this structure is an aggregrate so that it can be initialized using C++20
/// designated initializers to keep the tests readable.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
template <auto SIGNATURE>
struct Args;
/// @brief Non-owning input collection corresponding to a signature.
///
/// The `Input` structure represents the collection of input tensor data on the
/// device, associated to a particular SIGNATURE. The exact fields in this structure
/// may again depend on the exact SIGNATURE. This structure is non-owning: its use
/// is intended as a way to pass all inputs around as a single value.
///
/// This structure does not require additional member functions, any which are
/// provided should be considered implementation details of Args structure for
/// that particular SIGNATURE.
///
/// @note The implementation can just be a set of void-pointers which conceptually
/// represent the inputs of the device operation. It is the intention that this
/// structure is an aggregrate so that it can be initialized using C++20
/// designated initializers to keep the tests readable.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
template <auto SIGNATURE>
struct Inputs;
/// @brief Non-owning outputs collection corresponding to a signature.
///
/// The `Output` structure represents the collection of input tensor data on the
/// device, associated to a particular SIGNATURE. The exact fields in this structure
/// may again depend on the exact SIGNATURE. This structure is non-owning: its use
/// is intended as a way to pass all outputs around as a single value.
///
/// This structure does not require additional member functions, any which are
/// provided should be considered implementation details of Args structure for
/// that particular SIGNATURE.
///
/// @note The implementation can just be a set of void-pointers which conceptually
/// represent the outputs of the device operation. It is the intention that this
/// structure is an aggregrate so that it can be initialized using C++20
/// designated initializers to keep the tests readable.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
template <auto SIGNATURE>
struct Outputs;
/// @brief RAII-enabled inputs collection corresponding to a signature.
///
/// The `UniqueInputs` is used to automatically manage the memory of a set of
/// inputs. Unlike the corresponding `Inputs` structure, the implementation is
/// opaque; the only requirements for this structure is that an instance can
/// be created using `alloc_inputs()` and that an instance of the corresponding
/// `Inputs` structure can be obtained using `.get()`.
///
/// @note The easiest way to implement this type is to use the `DeviceBuffer`
/// type to allocate individual device buffers for each input tensor.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
///
/// @see alloc_inputs()
/// @see ValidUniqueInputs
/// @see DeviceBuffer
template <auto SIGNATURE>
struct UniqueInputs;
/// @brief RAII-enabled outputs collection corresponding to a signature.
///
/// The `UniqueOutputs` is used to automatically manage the memory of a set of
/// outputs. Unlike the corresponding `Outputs` structure, the implementation is
/// opaque; the only requirements for this structure is that an instance can
/// be created using `alloc_outputs()` and that an instance of the corresponding
/// `Outputs` structure can be obtained using `.get()`.
///
/// @note The easiest way to implement this type is to use the `DeviceBuffer`
/// type to allocate individual device buffers for each output tensor.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
///
/// @see alloc_outputs()
/// @see ValidUniqueOutputs
/// @see DeviceBuffer
template <auto SIGNATURE>
struct UniqueOutputs;
/// @brief Concept to check the validity of `UniqueInputs`.
///
/// The `ValidUniqueInputs` concept can be used to check whether the definition
/// of `UniqueInputs` is valid for a particular SIGNATURE.
///
/// - SIGNATURE is signature to specialize the structure for.
///
/// @see UniqueInputs
template <auto SIGNATURE>
concept ValidUniqueInputs = requires(UniqueInputs<SIGNATURE>& inputs) {
/// `.get()` is used to obtain a non-owning version of the `Inputs` collection.
{ inputs.get() } -> std::convertible_to<Inputs<SIGNATURE>>;
};
/// @brief Concept to check the validity of `UniqueOutputs`.
///
/// The `ValidUniqueOutputs` concept can be used to check whether the definition
/// of `UniqueOutputs` is valid for a particular SIGNATURE.
///
/// - SIGNATURE is signature to specialize the structure for.
///
/// @see UniqueOutputs
template <auto SIGNATURE>
concept ValidUniqueOutputs = requires(UniqueOutputs<SIGNATURE>& inputs) {
/// `.get()` is used to obtain a non-owning version of the `Outputs` collection.
{ inputs.get() } -> std::convertible_to<Outputs<SIGNATURE>>;
};
/// @brief Allocate inputs corresponding to a signature.
///
/// The `alloc_inputs()` function is used to create an instance of
/// `UniqueInputs`. This function uses the `args` structure to compute the
/// amount of memory required and then allocate it on the device, for example
/// using `alloc_buffer` or `alloc_tensor_buffer`.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
///
/// @see Inputs
/// @see UniqueInputs
/// @see alloc_buffer()
/// @see alloc_tensor_buffer()
template <auto SIGNATURE>
requires ValidUniqueInputs<SIGNATURE>
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args);
/// @brief Allocate outputs corresponding to a signature.
///
/// The `alloc_outputs()` function is used to create an instance of
/// `UniqueOutputs`. This function uses the `args` structure to compute the
/// amount of memory required and then allocate it on the device, for example
/// using `alloc_buffer` or `alloc_tensor_buffer`.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
///
/// @see Outputs
/// @see UniqueOutputs
/// @see alloc_buffer()
/// @see alloc_tensor_buffer()
template <auto SIGNATURE>
requires ValidUniqueOutputs<SIGNATURE>
UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args);
/// @brief Invoke a device operation created by CK Builder.
///
/// This is the main function used to invoke a particular device operation
/// instance created by the builder. It uses the `args`, `inputs`, and `outputs`
/// to configure the `operation` and invokes it immediately.
///
/// In practice, the `Operation` is usually a CK or CK Tile device operation
/// type, for example `DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3`.
/// This function implements the required functionality to invoke any relevant
/// type created by the builder.
///
/// @note Unlike the Args, Inputs, Outputs, and related structures, this function
/// is specialized for the different implementations that the builder may
/// return (see file-level documentation).
///
/// @pre The tensors in `inputs` should be allocated and initialized with the
/// appropriate values to perform the operation.
/// @pre The tensors in `outputs` should be allocated.
/// @post The tensors in `outputs` are overwritten with the outputs of the device
/// operation.
///
/// @tparam SIGNATURE the signature to specialize this function for
/// @tparam Operation the kernel of the operation to invoke. This type should be
/// one that is created using the Builder API.
/// @param operation An instance of the operation to invoke.
/// @param args The run-time arguments of the operation.
/// @param inputs The input tensor data. Will not be modified by this function.
/// @param outputs The output tensor data. The contents will be overwritten by
/// this function.
template <auto SIGNATURE, typename Operation>
void run(Operation& operation,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs);
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,50 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstddef>
#include "ck_tile/builder/types.hpp"
/// This file implements various backend-independent traits for
/// CK-Builder types.
namespace ck_tile::builder::test {
/// @brief Query the size of a data type in memory.
///
/// This function computes the size of a variant of `DataType` in memory.
/// This is more complicated than it seems. For most types, this is just
/// the size of the equivalent C++-type, but for sub-byte type we have to
/// represent each byte by multiple values, for example. For now, we only
/// care about types which consist of an integral number of bytes, though.
///
/// @note The details of this function are likely going to change with the
/// support of sub-byte types.
///
/// @param data_type The type to query the in-memory size of.
/// @returns The number of bytes that an element of this data type requires
/// in memory.
constexpr size_t data_type_sizeof(DataType data_type)
{
switch(data_type)
{
case DataType::UNDEFINED_DATA_TYPE: return 0;
case DataType::FP32: return 4;
case DataType::FP32_FP32: return 8;
case DataType::FP16: return 2;
case DataType::FP16_FP16: return 4;
case DataType::BF16: return 2;
case DataType::BF16_BF16: return 4;
case DataType::FP8: return 1;
case DataType::BF8: return 1;
case DataType::FP64: return 8;
case DataType::INT32: return 4;
case DataType::I8: return 1;
case DataType::I8_I8: return 2;
case DataType::U8: return 1;
}
return 0; // Default case to ensure all control paths return a value
}
} // namespace ck_tile::builder::test

View File

@@ -11,15 +11,22 @@
namespace ck_tile::builder {
// TODO: Handle tuple types and FP8/BF8 properly
enum class DataType
{
UNDEFINED_DATA_TYPE = 0,
FP32,
FP32_FP32,
FP16,
FP16_FP16,
BF16,
BF16_BF16,
FP8,
BF8,
FP64,
INT32,
I8,
I8_I8,
U8
};
@@ -102,13 +109,44 @@ enum class ConvDirection
};
// Fused element-wise operations.
// TODO: Generalize design rather than enumerating all possible ops.
enum class ElementwiseOperation
{
ADD_CLAMP,
ADD_RELU_ADD,
ACTIVATION_MUL2_CLAMP,
ACTIVATION_MUL_CLAMP,
ADD_ACTIVATION_MUL_CLAMP,
ADD_ACTIVATION_MUL2_CLAMP,
ADD_MUL_ACTIVATION_MUL_CLAMP,
ADD_MUL2_ACTIVATION_MUL_CLAMP,
BIAS_BNORM_CLAMP,
BILINEAR,
SCALE,
SCALE_ADD,
CLAMP,
CONV_INVSCALE,
CONV_SCALE,
CONV_SCALE_ADD,
CONV_SCALE_RELU,
PASS_THROUGH,
SCALEADD_SCALEADD_RELU
SCALEADD_SCALEADD_RELU,
DYNAMIC_UNARY_OP,
UNARY_COMBINED_OP,
UNARY_CONVERT,
LOGISTIC,
CLIPPED_RELU,
SWISH,
ELU,
POWER,
LEAKY_RELU,
UNARY_ABS,
RELU,
SOFT_RELU,
SIGMOID,
TANH,
GELU,
SILU
};
// Enums for pipeline versions & schedulers
@@ -160,7 +198,8 @@ enum class ConvFwdSpecialization
DEFAULT,
FILTER_1X1_PAD0,
FILTER_1X1_STRIDE1_PAD0,
FILTER_3x3
FILTER_3x3,
ODD_C
};
// Enums for the backward data convolution specialization.
@@ -219,11 +258,17 @@ inline std::string_view toString(DataType dt)
switch(dt)
{
case FP16: return "FP16";
case FP16_FP16: return "FP16_FP16";
case FP32: return "FP32";
case FP32_FP32: return "FP32_FP32";
case BF16: return "BF16";
case BF16_BF16: return "BF16_BF16";
case FP8: return "FP8";
case BF8: return "BF8";
case FP64: return "FP64";
case INT32: return "INT32";
case I8: return "I8";
case I8_I8: return "I8_I8";
case U8: return "U8";
case UNDEFINED_DATA_TYPE: return "UNDEFINED_DATA_TYPE";
default: return "Unknown";
@@ -247,11 +292,41 @@ inline std::string_view toString(ElementwiseOperation op)
using enum ElementwiseOperation;
switch(op)
{
case ADD_CLAMP: return "ADD_CLAMP";
case ADD_RELU_ADD: return "ADD_RELU_ADD";
case ACTIVATION_MUL2_CLAMP: return "ACTIVATION_MUL2_CLAMP";
case ACTIVATION_MUL_CLAMP: return "ACTIVATION_MUL_CLAMP";
case ADD_ACTIVATION_MUL_CLAMP: return "ADD_ACTIVATION_MUL_CLAMP";
case ADD_ACTIVATION_MUL2_CLAMP: return "ADD_ACTIVATION_MUL2_CLAMP";
case ADD_MUL_ACTIVATION_MUL_CLAMP: return "ADD_MUL_ACTIVATION_MUL_CLAMP";
case ADD_MUL2_ACTIVATION_MUL_CLAMP: return "ADD_MUL2_ACTIVATION_MUL_CLAMP";
case BIAS_BNORM_CLAMP: return "BIAS_BNORM_CLAMP";
case BILINEAR: return "BILINEAR";
case CLAMP: return "CLAMP";
case SCALE: return "SCALE";
case SCALE_ADD: return "SCALE_ADD";
case CONV_INVSCALE: return "CONV_INVSCALE";
case CONV_SCALE: return "CONV_SCALE";
case CONV_SCALE_ADD: return "CONV_SCALE_ADD";
case CONV_SCALE_RELU: return "CONV_SCALE_RELU";
case PASS_THROUGH: return "PASS_THROUGH";
case BIAS_BNORM_CLAMP: return "BIAS_BNORM_CLAMP";
case SCALEADD_SCALEADD_RELU: return "SCALEADD_SCALEADD_RELU";
case DYNAMIC_UNARY_OP: return "DYNAMIC_UNARY_OP";
case UNARY_COMBINED_OP: return "UNARY_COMBINED_OP";
case UNARY_CONVERT: return "UNARY_CONVERT";
case LOGISTIC: return "LOGISTIC";
case CLIPPED_RELU: return "CLIPPED_RELU";
case SWISH: return "SWISH";
case ELU: return "ELU";
case POWER: return "POWER";
case LEAKY_RELU: return "LEAKY_RELU";
case UNARY_ABS: return "UNARY_ABS";
case RELU: return "RELU";
case SOFT_RELU: return "SOFT_RELU";
case SIGMOID: return "SIGMOID";
case TANH: return "TANH";
case GELU: return "GELU";
case SILU: return "SILU";
default: return "Unknown";
}
}
@@ -305,6 +380,7 @@ inline std::string_view toString(ConvFwdSpecialization spec)
case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0";
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
case FILTER_3x3: return "FILTER_3x3";
case ODD_C: return "ODD_C";
default: return "Unknown";
}
}

View File

@@ -78,24 +78,27 @@ add_ck_builder_test(test_ckb_conv_builder
test_fwd_instance_traits.cpp
test_bwd_data_instance_traits.cpp
test_instance_traits_util.cpp
unit_device_buffer.cpp
unit_tensor_descriptor.cpp
unit_conv_elementwise_op.cpp
unit_conv_tensor_layout.cpp
unit_conv_tensor_type.cpp
unit_conv_thread_block.cpp
unit_conv_tuning_params.cpp)
# Tests the inline diff utility used for comparing strings in tests assertions
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
unit_conv_tuning_params.cpp
unit_conv_fwd_testing.cpp)
target_link_libraries(test_ckb_conv_builder PRIVATE utility)
# Tests the inline diff utility used for comparing strings in tests assertions
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
# Tests convolution trait selection and configuration
add_ck_builder_test(test_ckb_conv_traits
conv/ck/test_conv_traits.cpp)
# Tests convolution problem description and parameter handling
add_ck_builder_test(test_ckb_conv_description
test_conv_description.cpp)
# Tests convolution trait selection and configuration
add_ck_builder_test(test_ckb_conv_traits
conv/ck/test_conv_traits.cpp)
# Tests convolution problem description and parameter handling
add_ck_builder_test(test_ckb_conv_description
test_conv_description.cpp)
################################################################################
# REGRESSION TESTS - Integration Tests (With Kernel Compilation)
################################################################################
@@ -134,8 +137,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp
)
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp)
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
################################################################################

View File

@@ -4,46 +4,83 @@
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/builder/testing/conv_fwd_ck.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace {
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using namespace ck_tile::builder::test_utils;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::FORWARD,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(cku::FwdThreadBlock_256_256x256x32)
.with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(cku::FwdTransfer_4x64x1)
.with_specializations(ckb::ConvFwdSpecialization::DEFAULT,
ckb::GemmSpecialization::MNKPadding)
.with_block_gemm(cku::BlockGemmDesc_v3_intrawave);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(Fwd2DFp16_CShufV3_GNHWC, Create)
{
using enum ck_tile::builder::ConvDirection;
using enum ck_tile::builder::DataType;
using enum ck_tile::builder::TensorLayout;
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = FORWARD,
.data_type = FP16,
.accumulation_data_type = FP32,
.input = {.config = {.layout = GNHWC}},
.weight = {.config = {.layout = GKYXC}},
.output = {.config = {.layout = GNHWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v3_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
expected_transfer_parameters,
"Filter1x1Pad0",
"Intrawave",
"v3",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
const auto expected_transfer_parameters = to_string(ALGORITHM);
cku::run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
expected_transfer_parameters,
"Default",
"Intrawave",
"v3",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}
} // namespace
TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
{
if(!ck_tile::get_device_name().starts_with("gfx9"))
{
GTEST_SKIP() << "unsupported architecture";
}
ckt::Args<SIGNATURE> args = {
.lengths =
{
.batch_size = 16,
.groups = 1,
.input_channels = 32,
.output_channels = 48,
.image =
{
.width = 56,
.height = 64,
},
.filter =
{
.width = 3,
.height = 5,
},
},
.filter_strides = {.width = 1, .height = 1},
.filter_dilation = {.width = 1, .height = 1},
.input_left_pad = {.width = 0, .height = 0},
.input_right_pad = {.width = 0, .height = 0},
.a_elementwise_op = {},
.b_elementwise_op = {},
.cde_elementwise_op = {},
};
auto inputs = alloc_inputs(args);
auto outputs = alloc_outputs(args);
auto conv = Instance{};
ckt::run(conv, args, inputs.get(), outputs.get());
}

View File

@@ -5,6 +5,7 @@
#include <gmock/gmock.h>
#include <concepts>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>

View File

@@ -4,8 +4,9 @@
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <ck_tile/builder/conv_builder.hpp>
#include <ck_tile/builder/reflect/conv_description.hpp>
#include "ck_tile/builder/conv_builder.hpp"
#include "ck_tile/builder/reflect/conv_description.hpp"
#include "ck_tile/builder/reflect/conv_describe.hpp"
#include "testing_utils.hpp"
#include "impl/conv_signature_types.hpp"
#include "impl/conv_algorithm_types.hpp"

View File

@@ -72,14 +72,16 @@ std::string expected_str = "DeviceGroupedConvBwdWeight_Xdl_CShuffle"
",1" // MaxTransposeTransferSrcScalarPerVector
",1>"; // MaxTransposeTransferDstScalarPerVector
// Test GetInstanceString through base class pointer for backward weight XDL variant
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForBwdWeightGrpConvXdl)
// Test describe() through base class pointer for backward weight XDL variant
TEST(InstanceString, DescribeReturnsCorrectValueForBwdWeightGrpConvXdl)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
// TODO: Add DescriptionReturnsCorrectValueForBwdWeightGrpConvXdl test once ckr::describe supports

View File

@@ -2,10 +2,11 @@
// SPDX-License-Identifier: MIT
#include "gtest/gtest.h"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/conv_description.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
#include <ck_tile/builder/reflect/conv_describe.hpp>
#include <ck_tile/builder/reflect/conv_description.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
namespace {
@@ -77,14 +78,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
",Default" // LoopScheduler
",1>"; // NumGroupsToMerge
// Test GetInstanceString through base class pointer for standard XDL variant
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConv)
// Test describe() through base class pointer for standard XDL variant
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConv)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConv)

View File

@@ -71,14 +71,16 @@ std::string expected_str = "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"
",5" // CThreadTransferSrcDstVectorDim
",1>"; // CThreadTransferDstScalarPerVector
// Test GetInstanceString through base class pointer for DL variant
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvDl)
// Test describe() through base class pointer for DL variant
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvDl)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
// TODO: Add DescriptionReturnsCorrectValueForFwdGrpConvDl test once ckr::describe supports DL

View File

@@ -2,10 +2,11 @@
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/conv_describe.hpp>
#include <ck_tile/builder/reflect/conv_description.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
namespace {
@@ -76,14 +77,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Ten
",fp16" // BComputeDataType
",Default>"; // LoopScheduler
// Test GetInstanceString through base class pointer for large tensor variant
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvLargeTensor)
// Test describe() through base class pointer for large tensor variant
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvLargeTensor)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConvLargeTensor)

View File

@@ -3,6 +3,7 @@
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/conv_describe.hpp>
#include <ck_tile/builder/reflect/conv_description.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp>
@@ -78,14 +79,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
",fp16" // BComputeDataType
",false>"; // DirectLoad
// Test GetInstanceString through base class pointer for V3 variant
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvV3)
// Test describe() through base class pointer for V3 variant
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvV3)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConvV3)

View File

@@ -76,14 +76,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
",Default" // LoopSched
",v1>"; // PipelineVer
// Test GetInstanceString through base class pointer for WMMA variant
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvWmma)
// Test describe() through base class pointer for WMMA variant
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvWmma)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
// TODO: Add DescriptionReturnsCorrectValueForFwdGrpConvWmma test once ckr::describe supports WMMA

View File

@@ -5,6 +5,8 @@
#include "testing_utils.hpp"
using ck_tile::test::HipError;
using ck_tile::test::HipSuccess;
using ck_tile::test::InstanceMatcher;
using ck_tile::test::InstanceSet;
using ck_tile::test::StringEqWithDiff;
@@ -96,3 +98,12 @@ TEST(InstanceMatcher, ExplainMatchResult)
"Unexpected: 1\n"
"- python\n"));
}
TEST(HipStatusMatcher, Basic)
{
EXPECT_THAT(hipSuccess, HipSuccess());
EXPECT_THAT(hipErrorInvalidValue, HipError(hipErrorInvalidValue));
EXPECT_THAT(hipErrorInvalidValue, Not(HipSuccess()));
EXPECT_THAT(hipSuccess, Not(HipError(hipErrorInvalidValue)));
EXPECT_THAT(hipErrorOutOfMemory, Not(HipError(hipErrorInvalidValue)));
}

View File

@@ -11,6 +11,11 @@
#include <vector>
#include <algorithm>
std::ostream& operator<<(std::ostream& os, hipError_t status)
{
return os << hipGetErrorString(status);
}
namespace ck_tile::test {
// Wagner-Fischer Algorithm for Computing Edit Distance and Inline Diff
@@ -297,4 +302,41 @@ void InstanceMatcher::DescribeNegationTo(std::ostream* os) const
*os << "is not equal to " << expected_;
}
bool HipStatusMatcher::MatchAndExplain(hipError_t actual,
::testing::MatchResultListener* listener) const
{
(void)listener;
if(actual == expected_)
{
return true;
}
return false;
}
void HipStatusMatcher::DescribeTo(std::ostream* os) const { *os << hipGetErrorString(expected_); }
void HipStatusMatcher::DescribeNegationTo(std::ostream* os) const
{
if(expected_ == hipSuccess)
{
*os << "any error";
}
else
{
*os << "isn't equal to " << hipGetErrorString(expected_);
}
}
::testing::Matcher<hipError_t> HipSuccess()
{
return ::testing::MakeMatcher(new HipStatusMatcher(hipSuccess));
}
::testing::Matcher<hipError_t> HipError(hipError_t error)
{
return ::testing::MakeMatcher(new HipStatusMatcher(error));
}
} // namespace ck_tile::test

View File

@@ -11,6 +11,16 @@
#include <vector>
#include <array>
/// @brief ostream-overload for hipError
///
/// Google Test likes to print errors to ostream, and this provides integration
/// with that. Since we only expect to use this with CK-Builder's own tests,
/// providing this implementation seems not problematic, but if it starts to
/// clash with another implementation then we will need to provide this
/// implementation another way. Unfortunately Google Test does not have a
/// dedicated function to override to provide printing support.
std::ostream& operator<<(std::ostream& os, hipError_t status);
namespace ck_tile::test {
static bool isTerminalOutput() { return isatty(fileno(stdout)) || isatty(fileno(stderr)); }
@@ -109,4 +119,35 @@ struct InstanceMatcher : public ::testing::MatcherInterface<InstanceSet>
::testing::Matcher<InstanceSet> InstancesMatch(const InstanceSet& expected);
/// @brief Google Test hipError_t matcher.
///
/// This is a custom Google Test matcher implementation which can be used to
/// compare HIP status codes. Use `HipSuccess()` or `HipError()` to obtain
/// an instance.
///
/// @see HipSuccess
/// @see HipError
/// @see ::testing::MatcherInterface
struct HipStatusMatcher : public ::testing::MatcherInterface<hipError_t>
{
HipStatusMatcher(hipError_t expected) : expected_(expected) {}
bool MatchAndExplain(hipError_t actual,
::testing::MatchResultListener* listener) const override;
void DescribeTo(std::ostream* os) const override;
void DescribeNegationTo(std::ostream* os) const override;
hipError_t expected_;
};
/// @brief Construct a Google Test matcher that checks that a HIP operation
/// was successful.
::testing::Matcher<hipError_t> HipSuccess();
/// @brief Construct a Google Test matcher that checks that a HIP operation
/// returned a particular error code.
///
/// @param error The error to expect.
::testing::Matcher<hipError_t> HipError(hipError_t error);
} // namespace ck_tile::test

View File

@@ -0,0 +1,83 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "impl/conv_signature_types.hpp"
#include "testing_utils.hpp"
#include "ck_tile/builder/testing/conv_fwd.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
using ::testing::ElementsAreArray;
using ::testing::NotNull;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::FORWARD,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::NHWGC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::NHWGK}}};
constexpr ckt::Args<SIGNATURE> ARGS = {
.lengths =
{
.batch_size = 17,
.groups = 5,
.input_channels = 13,
.output_channels = 44,
.image =
{
.width = 99,
.height = 125,
},
.filter =
{
.width = 9,
.height = 4,
},
},
.filter_strides = {.width = 1, .height = 1},
.filter_dilation = {.width = 1, .height = 1},
.input_left_pad = {.width = 0, .height = 0},
.input_right_pad = {.width = 0, .height = 0},
.a_elementwise_op = {},
.b_elementwise_op = {},
.cde_elementwise_op = {},
};
using Inputs = ckt::Inputs<SIGNATURE>;
using Outputs = ckt::Outputs<SIGNATURE>;
using UniqueInputs = ckt::UniqueInputs<SIGNATURE>;
using UniqueOutputs = ckt::UniqueOutputs<SIGNATURE>;
static_assert(ckt::ValidUniqueInputs<SIGNATURE>);
static_assert(ckt::ValidUniqueOutputs<SIGNATURE>);
TEST(ConvFwdTesting, MakeDescriptors)
{
const auto get_lengths = [](const auto& descriptor) {
const auto lengths = descriptor.get_lengths();
// Google Test cannot print std::span, so turn it into a vector for
// legibility.
return std::vector(lengths.begin(), lengths.end());
};
EXPECT_THAT(get_lengths(ARGS.make_input_descriptor()), ElementsAreArray({5, 17, 13, 125, 99}));
EXPECT_THAT(get_lengths(ARGS.make_weight_descriptor()), ElementsAreArray({5, 44, 13, 4, 9}));
EXPECT_THAT(get_lengths(ARGS.make_output_descriptor()), ElementsAreArray({5, 17, 44, 122, 91}));
}
TEST(ConvFwdTesting, Alloc)
{
auto inputs = alloc_inputs(ARGS);
auto outputs = alloc_outputs(ARGS);
EXPECT_THAT(inputs.get().input, NotNull());
EXPECT_THAT(inputs.get().weight, NotNull());
EXPECT_THAT(outputs.get().output, NotNull());
}

View File

@@ -0,0 +1,81 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "testing_utils.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
using ck_tile::test::HipError;
using ck_tile::test::HipSuccess;
using ::testing::Eq;
using ::testing::IsNull;
using ::testing::NotNull;
using ::testing::Throws;
TEST(DeviceBuffer, DefaultToNull)
{
ckt::DeviceBuffer buffer;
EXPECT_THAT(buffer.get(), IsNull());
}
TEST(DeviceBuffer, AllocBuffer)
{
const auto size = 12345;
auto buffer = ckt::alloc_buffer(size);
// Pointer should be non-null
EXPECT_THAT(buffer.get(), NotNull());
// Actually, the pointer should be a device pointer
hipPointerAttribute_t attr;
EXPECT_THAT(hipPointerGetAttributes(&attr, buffer.get()), HipSuccess());
EXPECT_THAT(attr.devicePointer, NotNull());
EXPECT_THAT(attr.type, Eq(hipMemoryTypeDevice));
// Memory should be writable without error
EXPECT_THAT(hipMemset(buffer.get(), 0xFF, size), HipSuccess());
}
TEST(DeviceBuffer, AutoFree)
{
const auto size = 12345;
std::byte* ptr = nullptr;
{
auto buffer = ckt::alloc_buffer(size);
ptr = buffer.get();
}
// Trying to use a pointer after freeing should return en error in HIP.
EXPECT_THAT(hipMemset(ptr, 0xFF, size), HipError(hipErrorInvalidValue));
}
TEST(DeviceBuffer, ThrowsOnOom)
{
const auto size = size_t{1} << 60; // 1 exabyte
auto check = [] { auto buffer = ckt::alloc_buffer(size); };
EXPECT_THAT(check, Throws<ckt::OutOfDeviceMemoryError>());
}
TEST(DeviceBuffer, AllocTensorBuffer)
{
std::vector<size_t> lengths = {128, 128, 128};
std::vector<size_t> strides = {128 * 128, 128, 1};
ckt::TensorDescriptor<ckb::DataType::FP32> descriptor(lengths, strides);
auto buffer = ckt::alloc_tensor_buffer(descriptor);
// Pointer should be non-null
EXPECT_THAT(buffer.get(), NotNull());
// Memory should be writable without error
EXPECT_THAT(hipMemset(buffer.get(), 0xFF, descriptor.get_element_space_size_in_bytes()),
HipSuccess());
}

View File

@@ -0,0 +1,47 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "testing_utils.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
using ::testing::ElementsAreArray;
using ::testing::Ge;
TEST(TensorDescriptor, Basic)
{
constexpr auto dt = ckb::DataType::FP16;
std::vector<size_t> lengths = {123, 456, 789};
std::vector<size_t> strides = {456 * 789, 789, 1};
ckt::TensorDescriptor<dt> descriptor(lengths, strides);
EXPECT_THAT(descriptor.get_lengths(), ElementsAreArray(lengths));
EXPECT_THAT(descriptor.get_strides(), ElementsAreArray(strides));
}
TEST(TensorDescriptor, ComputeSize)
{
constexpr auto dt = ckb::DataType::FP32;
std::vector<size_t> lengths = {305, 130, 924};
std::vector<size_t> strides = {1000 * 1000, 1, 1000};
ckt::TensorDescriptor<dt> descriptor(lengths, strides);
// Compute the location of the last item in memory, then add one
// to get the minimum size.
size_t expected_size = 1;
for(size_t i = 0; i < lengths.size(); ++i)
{
expected_size += (lengths[i] - 1) * strides[i];
}
EXPECT_THAT(descriptor.get_element_space_size(), Ge(expected_size));
EXPECT_THAT(descriptor.get_element_space_size_in_bytes(),
Ge(expected_size * ckt::data_type_sizeof(dt)));
}

View File

@@ -72,7 +72,12 @@ inline bool is_xdl_supported()
is_gfx12_supported() || is_gfx11_supported();
}
template <typename ADataType, typename BDataType, index_t MPerXDL, index_t NPerXDL>
template <typename ADataType,
typename BDataType,
index_t MPerXDL64,
index_t NPerXDL64,
index_t MPerXDL32 = MPerXDL64,
index_t NPerXDL32 = NPerXDL64>
inline bool is_xdl_wmma_supported()
{
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
@@ -82,7 +87,7 @@ inline bool is_xdl_wmma_supported()
}
else if(is_gfx12_supported() || is_gfx11_supported())
{
if constexpr((MPerXDL != 16) || (NPerXDL != 16))
if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16))
{
return false;
}

View File

@@ -8,10 +8,16 @@
#include <sstream>
#include <regex>
#include <optional>
#include <memory>
#include "ck/stream_config.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#endif
#endif
#include "ck/utility/get_id.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
@@ -91,6 +97,57 @@ static constexpr auto GetNXdlPerWave2()
IsWave64>(); \
}
template <index_t BlockSize_,
index_t MPerBlock_,
index_t NPerBlock_,
index_t MPerXDL_,
index_t NPerXDL_,
index_t MXdlPerWave_,
index_t CShuffleMXdlPerWavePerShuffle_,
index_t CShuffleNXdlPerWavePerShuffle_,
bool IsWave64>
static constexpr auto GetWarpTileConfig()
{
constexpr auto MXdlPerWave64 = MXdlPerWave_;
constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16;
constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16;
constexpr auto NXdlPerWave =
IsWave64
? GetNXdlPerWave2<BlockSize_,
MPerBlock_,
NPerBlock_,
MPerXDL_,
NPerXDL_,
MXdlPerWave_,
true>()
: GetNXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
if constexpr(IsWave64 == false && NXdlPerWave != 0)
{
constexpr auto CShuffleNXdlPerWavePerShuffle32 =
NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
: CShuffleNXdlPerWavePerShuffle_;
static_assert(CShuffleNXdlPerWavePerShuffle32 > 0);
return Sequence<16,
16,
MXdlPerWave32,
NXdlPerWave,
CShuffleMXdlPerWavePerShuffle32,
CShuffleNXdlPerWavePerShuffle32>{};
}
else
{
return Sequence<MPerXDL_,
NPerXDL_,
MXdlPerWave64,
NXdlPerWave,
CShuffleMXdlPerWavePerShuffle_,
CShuffleNXdlPerWavePerShuffle_>{};
}
}
#define INVOKER_RUN_IMPL \
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
{ \
@@ -227,6 +284,12 @@ struct BaseOperator
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const { return ""; }
#ifdef CK_EXPERIMENTAL_BUILDER
// Return a description object for this operator, or nullptr if not supported.
virtual std::unique_ptr<ck_tile::reflect::Description> describe() const { return nullptr; }
#endif
virtual std::string GetInstanceString() const { return ""; }
virtual std::string GetTypeIdName() const { return typeid(*this).name(); }

View File

@@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
{
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
GET_NXDL_PER_WAVE_IMPL
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto WarpTileConfig64 = GetWarpTileConfig<BlockSize,
MPerBlock,
NPerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
true>();
static constexpr auto WarpTileConfig32 = GetWarpTileConfig<BlockSize,
MPerBlock,
NPerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
false>();
static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3);
static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3);
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
// GridwiseGemm
template <index_t NXdlPerWave_>
template <typename WarpTileConfig>
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
@@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave_,
WarpTileConfig::At(0),
WarpTileConfig::At(1),
WarpTileConfig::At(2),
WarpTileConfig::At(3),
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
@@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
WarpTileConfig::At(4),
WarpTileConfig::At(5),
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
using GridwiseGemm64 = GridwiseGemmBase<decltype(WarpTileConfig64)>;
using GridwiseGemm32 = GridwiseGemmBase<decltype(WarpTileConfig32)>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =
@@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_wmma_supported<ComputeDataType, ComputeDataType, MPerXDL, NPerXDL>())
if(!ck::is_xdl_wmma_supported<ComputeDataType,
ComputeDataType,
MPerXDL,
NPerXDL,
WarpTileConfig32.At(0),
WarpTileConfig32.At(1)>())
{
return false;
}
@@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< AK1 << ", "
<< BK1 << ", "
<< ABlockTransferSrcVectorDim << ", "

View File

@@ -25,6 +25,7 @@
#include "ck/host_utility/kernel_launch.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#endif
@@ -1240,6 +1241,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override

View File

@@ -25,6 +25,7 @@
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#endif
@@ -1064,6 +1065,11 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};

View File

@@ -29,6 +29,7 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/conv_describe.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#endif
@@ -2080,6 +2081,21 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
static_assert(ck_tile::reflect::conv::HasConvTraits<DeviceOp>,
"ConvTraits specialization not found for this device operation. "
"If you modified the template parameters of this class, ensure that "
"the corresponding ConvTraits specialization in "
"ck_tile/builder/reflect/conv_traits.hpp is updated to match, or that "
"InstanceTraits in "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp "
"provides all required members for ConvTraits to work.");
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
ck_tile::reflect::describe<DeviceOp>());
}
#endif
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override

View File

@@ -29,6 +29,7 @@
#include "ck/host_utility/flush_cache.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/conv_describe.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#endif
@@ -2103,6 +2104,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
ck_tile::reflect::describe<DeviceOp>());
}
#endif
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override

View File

@@ -25,6 +25,7 @@
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#endif
@@ -1019,6 +1020,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};

View File

@@ -25,6 +25,7 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/conv_describe.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#endif
@@ -1238,6 +1239,22 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
static_assert(
ck_tile::reflect::conv::HasConvTraits<DeviceOp>,
"ConvTraits specialization not found for this device operation. "
"If you modified the template parameters of this class, ensure that "
"the corresponding ConvTraits specialization in "
"ck_tile/builder/reflect/conv_traits.hpp is updated to match, or that "
"InstanceTraits in "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp "
"provides all required members for ConvTraits to work.");
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
ck_tile::reflect::describe<DeviceOp>());
}
#endif
};

View File

@@ -620,7 +620,44 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool isWave64 = get_warp_size() == 64;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
const auto& a = arg.gemm_kernel_args_[i].karg_;
// Validate stride requirements for SplitK (k_batch > 1)
// TODO: Enable splitK
if(a.k_batch > 1)
{
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
if(a.StrideC != a.N)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " SplitK (k_batch=" << a.k_batch
<< ") requires contiguous output stride."
<< " For RowMajor layout: StrideC must equal N."
<< " Got StrideC=" << a.StrideC << ", N=" << a.N << std::endl;
}
return false;
}
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
if(a.StrideC != a.M)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " SplitK (k_batch=" << a.k_batch
<< ") requires contiguous output stride."
<< " For ColumnMajor layout: StrideC must equal M."
<< " Got StrideC=" << a.StrideC << ", M=" << a.M << std::endl;
}
return false;
}
}
}
bool group_arg_valid = false;
if(isWave64)
{

View File

@@ -527,11 +527,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
}
else
{
#if defined(__gfx11__)
// TODO: remove this restriction
static_assert(ScaleBlockM >= MPerWmma,
"ScaleBlockM must be greater equal than MPerWmma");
#endif
static_assert(
ScaleBlockK >=
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::

View File

@@ -366,6 +366,26 @@ struct amdgcn_compiler_target_state
#else
static constexpr bool CK_TILE_ARCH_GFX1010 = false;
#endif
#if defined(__gfx1011__)
static constexpr bool CK_TILE_ARCH_GFX1011 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1011 = false;
#endif
#if defined(__gfx1012__)
static constexpr bool CK_TILE_ARCH_GFX1012 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1012 = false;
#endif
#if defined(__gfx1013__)
static constexpr bool CK_TILE_ARCH_GFX1013 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1013 = false;
#endif
#if defined(__gfx10_1_generic__)
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = true;
#else
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = false;
#endif // __gfx10_1_generic__
#if defined(__gfx1030__)
static constexpr bool CK_TILE_ARCH_GFX1030 = true;
@@ -504,6 +524,10 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1011, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1012, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1013, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_1_GENERIC, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \

View File

@@ -20,7 +20,7 @@ CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, cons
{
for(int m = 0; m < M; ++m)
{
if(mask.IsOutOfBound(m, n))
if(mask.IsOutOfSinkBound(m, n))
c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
}
}

View File

@@ -34,77 +34,80 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0, v_block_acc = 0;
AccDataType v_acc = 0;
static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
std::is_same_v<ADataType, bf8_t>);
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, pk_int4_t>);
static_assert(std::is_same_v<AccDataType, float>);
static_assert(std::is_same_v<CDataType, float> ||
std::is_same_v<CDataType, ck_tile::half_t>);
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_a;
AccDataType v_b;
constexpr std::size_t kGroupK = QuantGroupSize::kK;
// ---- A loader: dequant A(m,k) into AccDataType ----
auto load_a = [&](std::size_t k) -> AccDataType {
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
return (k & 1) ? fp32_val.hi : fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
return ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
};
// ---- B loader: dequant B(k,n) into AccDataType ----
auto load_b = [&](std::size_t k) -> AccDataType {
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
return (k & 1) ? fp32_val.hi : fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
return fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
return ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_block_acc += v_a * v_b;
};
// Apply group dequant scale
if((k + 1) % QuantGroupSize::kK == 0)
// ---- scale loader for a given K-group index ----
auto load_scale = [&](ck_tile::index_t k_group) -> float {
const ck_tile::index_t outer_dim = aquant ? (m / QuantGroupSize::kM) : k_group;
const ck_tile::index_t inner_dim = aquant ? k_group : (n / QuantGroupSize::kN);
if constexpr(std::is_same_v<QDataType, float>)
{
float scale = 0.f;
index_t outer_dim = (aquant) ? (m / QuantGroupSize::kM) : (k / QuantGroupSize::kK);
index_t inner_dim = (aquant) ? (k / QuantGroupSize::kK) : (n / QuantGroupSize::kN);
if constexpr(std::is_same_v<QDataType, float>)
{
scale = q(outer_dim, inner_dim);
}
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
{
scale = fp8_to_float_raw(q(outer_dim, inner_dim));
}
else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
{
scale = bf8_to_float_raw(q(outer_dim, inner_dim));
}
else
{
static_assert(false, "Unexpected Q datatype.");
}
v_block_acc *= scale;
v_acc += v_block_acc;
v_block_acc = 0;
return q(outer_dim, inner_dim);
}
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
{
return fp8_to_float_raw(q(outer_dim, inner_dim));
}
else // QDataType == bf8_t by static_assert above
{
return bf8_to_float_raw(q(outer_dim, inner_dim));
}
};
// ---- Loop over K by groups (full and tail) ----
for(std::size_t k_begin = 0; k_begin < K; k_begin += kGroupK)
{
const std::size_t k_end = std::min<std::size_t>(k_begin + kGroupK, K);
AccDataType v_block_acc = 0;
// unscaled accumulation within this K-group
for(std::size_t k = k_begin; k < k_end; ++k)
{
const AccDataType v_a = load_a(k);
const AccDataType v_b = load_b(k);
v_block_acc += v_a * v_b;
}
const ck_tile::index_t k_group = static_cast<ck_tile::index_t>(k_begin / kGroupK);
const float scale = load_scale(k_group);
v_acc += v_block_acc * scale;
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));

View File

@@ -84,7 +84,7 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
@@ -94,10 +94,10 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
@@ -114,18 +114,24 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
gemmConfig.K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
return shuffle_b(t, GemmConfig{});
}
template <typename GemmConfig, typename T>
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
@@ -145,22 +151,22 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
@@ -177,17 +183,23 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
gemmConfig.K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
{
return shuffle_b_permuteN(t, GemmConfig{});
}
} // namespace ck_tile

View File

@@ -8,7 +8,8 @@
namespace ck_tile {
enum StreamKReductionStrategy : uint32_t
{
Atomic = 0u,
Reduction = 1u
Atomic = 0u,
Reduction = 1u,
TreeReduction = 2u
};
} // namespace ck_tile

View File

@@ -35,7 +35,8 @@ template <typename AsDataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false,
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
bool DoubleSmemBuffer_ = false>
struct CShuffleEpilogueProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
@@ -59,6 +60,7 @@ struct CShuffleEpilogueProblem
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -118,6 +120,7 @@ struct CShuffleEpilogue
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
@@ -204,6 +207,26 @@ struct CShuffleEpilogue
}
return max_vector_size / sizeof(DiDataType);
}
/**
* @brief Shuffle tile configuration parameters check and aligment
*
* @details Return tuple(1, 1) if shuffle_tile values are too large for SMEM.
*/
template <index_t m_shuffle_tile, index_t n_shuffle_tile>
CK_TILE_HOST_DEVICE static constexpr auto AlignShuffleTileWithSmem()
{
constexpr index_t m_val = MPerXdl * MWave * m_shuffle_tile;
constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;
constexpr auto shuffle_tile =
m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
? std::make_tuple(1, 1)
: std::make_tuple(m_shuffle_tile, n_shuffle_tile);
return shuffle_tile;
}
/**
* @brief Shuffle tile configuration parameters
*
@@ -214,20 +237,23 @@ struct CShuffleEpilogue
*/
static constexpr auto shuffle_tile_tuple = [] {
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
if constexpr(elem_per_thread >= GetVectorSizeC())
if constexpr(elem_per_thread <= GetVectorSizeC())
{
return std::make_tuple(1, 1);
}
else
{
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
constexpr index_t num_xdl_shuffles = elem_per_thread / GetVectorSizeC();
static_assert(elem_per_thread % GetVectorSizeC() == 0);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
"kMPerBlock must be divisible by MPerXdl*MWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
return AlignShuffleTileWithSmem<min(num_xdl_shuffles,
kMPerBlock / (MPerXdl * MWave)),
1>();
}
else
{
@@ -235,7 +261,9 @@ struct CShuffleEpilogue
(kNPerBlock % num_xdl_shuffles == 0),
"kNPerBlock must be divisible by NPerXdl*NWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
return AlignShuffleTileWithSmem<1,
min(num_xdl_shuffles,
kNPerBlock / (NPerXdl * NWave))>();
}
}
}();

View File

@@ -86,21 +86,22 @@ struct GenericAttentionMask
static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
: GenericAttentionMask(0, 0, y_total_, x_total_)
: GenericAttentionMask(0, 0, 0, y_total_, x_total_)
{
}
CK_TILE_HOST_DEVICE
GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
GenericAttentionMask(index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{}))
sink(mask_coord.at(number<2>{})),
y_total(mask_coord.at(number<3>{})),
x_total(mask_coord.at(number<4>{}))
{
}
@@ -141,6 +142,44 @@ struct GenericAttentionMask
}
}
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, 0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
if constexpr(IsLocal)
{
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}
else
{
return 0;
}
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
if(x_start <= sink_seq_end && sink > 0)
return ck_tile::make_tuple(0, 0, x_end);
else
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
}
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
@@ -195,6 +234,30 @@ struct GenericAttentionMask
}
}
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
return i_x >= x_total;
// no need to do min/max here, since i_x will never be < 0 or >= x_total
index_t x_start = -y + i_y + 1;
index_t x_end = min(i_y + x, x_total);
if constexpr(IsLocal)
{
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
return false;
else
return i_x < x_start || i_x >= x_end;
}
else
{
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
return false;
else
return i_x >= x_end || i_y >= y_total;
}
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
@@ -237,7 +300,7 @@ struct GenericAttentionMask
}
private:
index_t y, x;
index_t y, x, sink;
index_t y_total, x_total;
};
@@ -260,21 +323,23 @@ struct SimplifiedGenericAttentionMask
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
: SimplifiedGenericAttentionMask(0, 0, 0, y_total_, x_total_)
{
}
CK_TILE_HOST_DEVICE
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
SimplifiedGenericAttentionMask(
index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{}))
sink(mask_coord.at(number<2>{})),
y_total(mask_coord.at(number<3>{})),
x_total(mask_coord.at(number<4>{}))
{
}
@@ -308,6 +373,38 @@ struct SimplifiedGenericAttentionMask
}
}
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, 0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
if(x_start <= sink_seq_end && sink > 0)
return ck_tile::make_tuple(0, 0, x_end);
else
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
}
}
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
number<TileHeight> height,
@@ -325,6 +422,29 @@ struct SimplifiedGenericAttentionMask
ck_tile::min(origin_end, split_end));
}
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto GetSinkTileRangeAlongX(index_t i_y,
number<TileHeight> height,
number<TileWidth> width,
index_t num_splits,
index_t i_split) const
{
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split; // 128
const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); // 256
const index_t sink_seq_end = sink > 0 ? ((sink + width - 1) / width) * width : 0;
const index_t start = ck_tile::max(origin_start, split_start);
const index_t end = ck_tile::min(origin_end, split_end);
const bool is_first_intersecting_split =
(split_start <= origin_start && split_end >= origin_start);
const bool sink_in_range = (sink_seq_end <= start);
const index_t sink_offset =
(is_first_intersecting_split && sink_in_range) ? sink_seq_end : 0;
return ck_tile::make_tuple(sink_offset, start, end);
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
@@ -368,11 +488,22 @@ struct SimplifiedGenericAttentionMask
{
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
}
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
return i_x >= x_total;
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
return false;
else
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
@@ -406,7 +537,7 @@ struct SimplifiedGenericAttentionMask
}
private:
index_t y, x;
index_t y, x, sink;
index_t y_total, x_total;
};
@@ -620,6 +751,7 @@ static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask<Ma
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
index_t right_size,
index_t sink_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
@@ -637,7 +769,21 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
index_t x = 1 + right_size + x_tmp;
index_t y = 1 + left_size + y_tmp;
return ck_tile::make_tuple(y, x, y_total, x_total);
return ck_tile::make_tuple(y, x, sink_size, y_total, x_total);
}
template <typename MaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_from_lr_window(index_t left_size,
index_t right_size,
index_t sink_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, sink_size, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), sink_size, y_total, x_total};
}
template <typename MaskType>
@@ -649,7 +795,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
left_size, right_size, 0, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), 0, y_total, x_total};
}
} // namespace ck_tile

View File

@@ -162,6 +162,17 @@ struct StandardAttention
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
template <typename Params>
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
}
};
template <bool UseExp2 = false>
@@ -224,6 +235,17 @@ struct LogitsSoftCap
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
template <typename Params>
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
}
};
constexpr uint32_t CUSTOM_MASK = 1U;
@@ -297,6 +319,17 @@ struct ComposedAttention
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
template <typename Params>
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
}
};
} // namespace ck_tile

View File

@@ -200,7 +200,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -356,6 +356,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -418,6 +419,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -497,6 +499,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t batch_stride_v,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -557,6 +560,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -1008,6 +1012,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -58,6 +58,7 @@ struct FmhaFwdKernel
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
@@ -155,7 +156,7 @@ struct FmhaFwdKernel
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -335,6 +336,7 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -393,6 +395,7 @@ struct FmhaFwdKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -481,6 +484,7 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -529,6 +533,7 @@ struct FmhaFwdKernel
batch_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
p_drop,
s_randval,
@@ -580,6 +585,7 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -628,6 +634,7 @@ struct FmhaFwdKernel
batch_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
p_drop,
s_randval,
@@ -673,6 +680,7 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
@@ -732,6 +740,7 @@ struct FmhaFwdKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -817,6 +826,7 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
@@ -861,6 +871,7 @@ struct FmhaFwdKernel
nhead_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
min_seqlen_q,
p_drop,
@@ -908,6 +919,7 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
@@ -952,6 +964,7 @@ struct FmhaFwdKernel
nhead_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
min_seqlen_q,
p_drop,
@@ -1443,6 +1456,7 @@ struct FmhaFwdKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
@@ -2182,6 +2196,7 @@ struct FmhaFwdKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -55,6 +55,7 @@ struct FmhaFwdPagedKVKernel
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
@@ -101,7 +102,7 @@ struct FmhaFwdPagedKVKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
#undef _SS_
#undef _TS_
// clang-format on
@@ -189,7 +190,7 @@ struct FmhaFwdPagedKVKernel
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -326,6 +327,7 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
@@ -379,6 +381,7 @@ struct FmhaFwdPagedKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -453,6 +456,7 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
return MakeKargsImpl(q_ptr,
@@ -495,6 +499,7 @@ struct FmhaFwdPagedKVKernel
batch_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type);
}
@@ -536,6 +541,7 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q)
{
@@ -590,6 +596,7 @@ struct FmhaFwdPagedKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -660,6 +667,7 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q)
{
@@ -699,6 +707,7 @@ struct FmhaFwdPagedKVKernel
batch_stride_v,
window_size_left,
window_size_right,
sink_size,
mask_type,
min_seqlen_q);
}
@@ -1257,6 +1266,7 @@ struct FmhaFwdPagedKVKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -51,6 +51,7 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink;
static constexpr bool kMergeNumHeadGroupsSeqLenQ =
FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
@@ -101,7 +102,7 @@ struct FmhaFwdSplitKVKernel
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
#undef _SS_
#undef _TS_
// clang-format on
@@ -198,7 +199,7 @@ struct FmhaFwdSplitKVKernel
struct MaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -325,6 +326,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
@@ -384,6 +386,7 @@ struct FmhaFwdSplitKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kDoFp8StaticQuant)
@@ -451,6 +454,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
@@ -508,6 +512,7 @@ struct FmhaFwdSplitKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kDoFp8StaticQuant)
@@ -994,6 +999,7 @@ struct FmhaFwdSplitKVKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -57,6 +57,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -228,10 +229,22 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
else
{
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
@@ -255,7 +268,6 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
return o_acc;
}
}
// k_dram_block_window
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
@@ -274,27 +286,36 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
return physical_seqlen_k_start_;
}
}();
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
? aligned_physical_seqlen_k_start
: 0;
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
num_sink_loop;
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
k_dram_block_window_lengths, {kv_load_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
const index_t bias_n_offset = [&]() {
if constexpr(kHasSink)
return kv_load_start;
else
return logical_seqlen_k_start -
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
}();
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
{bias_origin.at(number<0>{}), bias_n_offset},
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
// v_dram_window
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
{0, kv_load_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
// prefetch K tile
@@ -321,9 +342,16 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
k_block_tile = load_tile(k_dram_window);
}
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
const auto k_move_offset = [&]() {
if constexpr(kHasSink)
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
else
return kN0;
}();
auto physical_next_block_id_k =
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
i_page_block_k, k_dram_block_window, {kN0, 0}));
i_page_block_k, k_dram_block_window, {k_move_offset, 0}));
auto physical_next_block_id_v = amd_wave_read_first_lane(
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
@@ -442,7 +470,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
move_tile_window(bias_dram_window, {0, k_move_offset});
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
@@ -474,14 +502,29 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col - kv_l2p_offset);
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask_func(row, col - kv_l2p_offset);
});
};
if constexpr(kHasSink)
{
apply_mask([&](auto row, auto col) {
return mask.IsOutOfSinkBound(row, col);
});
}
else
{
apply_mask(
[&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
}
}
}
}
@@ -647,7 +690,12 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
}
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k);
physical_next_block_id_v =
amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}));
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v);
// tail
{
block_sync_lds();

View File

@@ -57,6 +57,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -256,11 +257,23 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
else
{
auto [start, end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
// check early exit if no work to do
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
const index_t logical_num_total_loop =
@@ -304,24 +317,33 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
return physical_seqlen_k_start_;
}
}();
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
? aligned_physical_seqlen_k_start
: 0;
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
num_sink_loop;
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
k_dram_block_window_lengths, {kv_load_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
const index_t bias_n_offset = [&]() {
if constexpr(kHasSink)
return kv_load_start;
else
return logical_seqlen_k_start -
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
}();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
{bias_origin.at(number<0>{}), bias_n_offset},
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
{0, kv_load_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// store Q into LDS
@@ -369,7 +391,13 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
const auto k_move_offset = [&]() {
if constexpr(kHasSink)
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
else
return kN0;
}();
// load the second tile of the first iteration
k_block_tile = load_tile(k_dram_window);
@@ -494,7 +522,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
move_tile_window(bias_dram_window, {0, k_move_offset});
/// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits)
@@ -505,7 +533,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
@@ -530,12 +558,26 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col - kv_l2p_offset);
});
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask_func(row, col - kv_l2p_offset);
});
};
if constexpr(kHasSink)
{
apply_mask(
[&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); });
}
else
{
apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
}
}
}
@@ -546,7 +588,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
{
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
i_page_block_k, k_dram_block_window, {k_move_offset, 0});
k_dram_window = make_tile_window(
k_dram_block_window,
@@ -742,6 +784,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0});
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0});
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
}
} while(++i_total_loops < num_total_loop);

View File

@@ -56,6 +56,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -229,9 +230,23 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
else
{
auto [start, end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
@@ -274,24 +289,35 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return physical_seqlen_k_start_;
}
}();
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
? aligned_physical_seqlen_k_start
: 0;
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
num_sink_loop;
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
k_dram_block_window_lengths, {kv_load_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
const index_t bias_n_offset = [&]() {
if constexpr(kHasSink)
return kv_load_start;
else
return logical_seqlen_k_start -
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
}();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
{bias_origin.at(number<0>{}), bias_n_offset},
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
{0, kv_load_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
@@ -320,9 +346,18 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
k_block_tile = load_tile(k_dram_window);
}
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
const auto k_move_offset = [&]() {
if constexpr(kHasSink)
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
else
return kN0;
}();
auto physical_next_block_id_k =
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
i_page_block_k, k_dram_block_window, {kN0, 0}));
i_page_block_k, k_dram_block_window, {k_move_offset, 0}));
auto physical_next_block_id_v = amd_wave_read_first_lane(
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
@@ -441,7 +476,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
move_tile_window(bias_dram_window, {0, k_move_offset});
/// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits)
@@ -452,7 +487,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
@@ -477,12 +512,26 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col - kv_l2p_offset);
});
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask_func(row, col - kv_l2p_offset);
});
};
if constexpr(kHasSink)
{
apply_mask(
[&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); });
}
else
{
apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
}
}
}
@@ -647,7 +696,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k);
physical_next_block_id_v =
amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}));
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v);
// tail
{
block_sync_lds();

View File

@@ -62,6 +62,7 @@ struct BlockFmhaPipelineProblem
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr auto QScaleEnum = Traits::QScaleEnum;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
};
template <typename QDataType_,
@@ -114,6 +115,7 @@ struct BlockFmhaFwdPagedKVPipelineProblem
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
};
template <typename QDataType_,
@@ -167,6 +169,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
};
// extract tile size attributes to remove dependency on traits

View File

@@ -57,6 +57,7 @@ struct BlockFmhaPipelineQRKSVS
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasSink = Problem::kHasSink;
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
@@ -233,10 +234,26 @@ struct BlockFmhaPipelineQRKSVS
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
const auto tile_range_result = [&mask, &q_origin]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
else
{
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
const auto num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
@@ -262,22 +279,22 @@ struct BlockFmhaPipelineQRKSVS
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
{kv_load_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
randval_dram_block_window_tmp, kv_load_start);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
{0, kv_load_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
@@ -450,6 +467,11 @@ struct BlockFmhaPipelineQRKSVS
#endif
}
}
if constexpr(kHasSink)
{
if(i_total_loops == 0)
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
@@ -460,17 +482,34 @@ struct BlockFmhaPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask_func(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
};
if constexpr(kHasSink)
{
apply_mask([&](auto&&... args) {
return variant.LogitsSinkMask(std::forward<decltype(args)>(args)...);
});
}
else
{
apply_mask([&](auto&&... args) {
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
});
}
}
}
@@ -580,11 +619,23 @@ struct BlockFmhaPipelineQRKSVS
if constexpr(kHasDropout)
{
// K and dropout use the same address in LDS, finish loading from k_lds_window by
// gemm_0 to reuse LDS.
block_sync_lds();
auto randval_ptr = reinterpret_cast<char*>(smem_ptr);
index_t seq_offset = [&]() {
if constexpr(!kHasSink)
return seqlen_k_start + i_total_loops * kN0;
const bool in_sink_phase = (num_sink_loop > i_total_loops);
if(i_total_loops == num_sink_loop)
move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end});
return in_sink_phase ? (kv_load_start + i_total_loops * kN0)
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
}();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
randval_ptr, seq_offset, p_compute, randval_dram_window);
}
block_sync_lds();
@@ -636,6 +687,14 @@ struct BlockFmhaPipelineQRKSVS
});
}
// move K tile windows
if constexpr(kHasSink)
{
if(i_total_loops == 0)
{
move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0});
move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end});
}
}
move_tile_window(k_dram_block_window, {kN0, 0});
// tail
{

View File

@@ -62,6 +62,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -277,11 +278,26 @@ struct BlockFmhaPipelineQRKSVSAsync
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
else
{
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
const auto num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
@@ -309,7 +325,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
{kv_load_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
@@ -332,16 +348,16 @@ struct BlockFmhaPipelineQRKSVSAsync
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
randval_dram_block_window_tmp, kv_load_start);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
{0, kv_load_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
@@ -478,6 +494,11 @@ struct BlockFmhaPipelineQRKSVSAsync
#endif
}
}
if constexpr(kHasSink)
{
if(i_total_loops == 0)
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
@@ -489,17 +510,34 @@ struct BlockFmhaPipelineQRKSVSAsync
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask_func(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
};
if constexpr(kHasSink)
{
apply_mask([&](auto&&... args) {
return variant.LogitsSinkMask(std::forward<decltype(args)>(args)...);
});
}
else
{
apply_mask([&](auto&&... args) {
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
});
}
}
}
@@ -647,11 +685,21 @@ struct BlockFmhaPipelineQRKSVSAsync
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
index_t seq_offset = [&]() {
if constexpr(!kHasSink)
return seqlen_k_start + i_total_loops * kN0;
const bool in_sink_phase = (num_sink_loop > i_total_loops);
if(i_total_loops == num_sink_loop)
move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end});
return in_sink_phase ? (kv_load_start + i_total_loops * kN0)
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
}();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
randval_ptr, seq_offset, p_compute, randval_dram_window);
}
const auto p = [&]() {
@@ -717,8 +765,16 @@ struct BlockFmhaPipelineQRKSVSAsync
i_total_loops++;
if(i_total_loops < num_total_loop)
{
// move K tile windows
if constexpr(kHasSink)
{
if(i_total_loops == 0)
{
move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0});
move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end});
}
}
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(k1_loops >= 2 &&

View File

@@ -69,6 +69,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasUnevenSplits = true;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||

View File

@@ -20,8 +20,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kStoreLSE_,
bool kHasDropout_,
BlockAttentionQuantScaleEnum QScaleEnum_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
bool kHasSink_ = false>
struct TileFmhaTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -36,6 +37,7 @@ struct TileFmhaTraits
static constexpr auto QScaleEnum = QScaleEnum_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
static constexpr bool kHasSink = kHasSink_;
};
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
@@ -65,8 +67,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
bool kIsPagedKV_,
bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
bool kHasSink_ = false>
struct TileFmhaFwdPagedKVTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -81,6 +84,7 @@ struct TileFmhaFwdPagedKVTraits
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
static constexpr bool kHasSink = kHasSink_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
@@ -95,7 +99,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kIsPagedKV_,
bool kHasUnevenSplits_,
bool kMergeNumHeadGroupsSeqLenQ_ = false,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kHasSink_ = false>
struct TileFmhaFwdSplitKVTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -112,6 +117,7 @@ struct TileFmhaFwdSplitKVTraits
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kHasSink = kHasSink_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,

View File

@@ -986,6 +986,8 @@ struct MoeSortingKernel
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
}
}
__syncthreads();
smem_cumdup(num_experts) = smem_cumsum(num_experts);
// fill the p_sorted_token_ids/p_sorted_weights

View File

@@ -33,9 +33,10 @@
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"

View File

@@ -232,7 +232,7 @@ struct BatchedGemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr1[GetSmemSize()];
__shared__ char smem_ptr1[GemmPipeline::GetSmemSize()];
UniversalGemmKernel::RunGemm2LDS({a_ptr},
{b_ptr},
{/*ds_ptr*/},

View File

@@ -310,7 +310,7 @@ struct GroupedGemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
c_ptr,
@@ -561,6 +561,7 @@ struct GroupedGemmKernel
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
block_sync_lds();
block_id = block_id + grid_size; // advance to next block
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
if(block_id >= cum_grid_size)

View File

@@ -0,0 +1,35 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
namespace ck_tile {
template <typename CompilerTarget, typename Enabler = void>
struct StreamKCoherency
{
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
amd_buffer_coherence_enum::coherence_default;
};
template <typename CompilerTarget>
struct StreamKCoherency<CompilerTarget,
core::arch::enable_if_target_id_t<CompilerTarget,
core::arch::amdgcn_target_id::GFX942,
core::arch::amdgcn_target_id::GFX950>>
{
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
amd_buffer_coherence_enum::SYSTEM_NT0;
};
template <typename CompilerTarget>
struct StreamKCoherency<CompilerTarget,
core::arch::enable_if_target_id_t<CompilerTarget,
core::arch::amdgcn_target_id::GFX908,
core::arch::amdgcn_target_id::GFX90A>>
{
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
amd_buffer_coherence_enum::glc_slc;
};
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "streamk_gemm_coherency.hpp"
namespace ck_tile {
@@ -318,37 +319,58 @@ struct StreamKKernel
* results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the current thread block (CTA).
* @note This function utilizes a workgroup barrier to set a synchronization flag for the given
* CTA index.
* @note This function utilizes a scalar store to write to the flags buffer.
*/
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
index_t cta_idx) const
{
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
workgroup_barrier sk_flags(sk_flags_ptr);
sk_flags.wait_set(0, 1, cta_idx);
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t offset = cta_idx * sizeof(index_t);
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the approproriate
// cache level(s) to ensure the write is visible to other workgroups. See the
// appropriate ISA for details about the GLC modifier.
"s_store_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
:
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
: "memory");
}
/**
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
* set by the given CTA index.
* @note This function utilizes a scalar load to read from the flags
* buffer.
*/
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
{
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
workgroup_barrier sk_flags(sk_flags_ptr);
sk_flags.wait_eq(1, cta_idx);
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t result;
index_t offset = cta_idx * sizeof(index_t);
do
{
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the
// approproriate cache level(s) to avoid reading stale flags. See the
// appropriate ISA for details about the GLC modifier.
"s_load_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
: "=s"(result)
: "s"(sk_flags_ptr), "s"(offset)
: "memory");
} while(result != 1);
}
/**
* @brief Adds the values of a block tile to an output block tile.
* @param in_out_block_tile The output block tile to which values are added.
* @param in_block_tile The input block tile whose values are added.
* @note This function iterates over the distributed spans of the block tiles and updates the
* output block tile with accumulated values.
* @note This function iterates over the distributed spans of the block tiles and updates
* the output block tile with accumulated values.
*/
template <typename OAccTile>
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
@@ -370,8 +392,8 @@ struct StreamKKernel
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile_dist The tile distribution for the block.
* @return The loaded partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for loading
* the partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for
* loading the partial block tile.
*/
template <typename DataType, typename OAccTileDist>
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
@@ -405,8 +427,8 @@ struct StreamKKernel
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile The block tile to be stored.
* @note This function calculates the buffer pointer and uses the tile window for storing the
* partial block tile.
* @note This function calculates the buffer pointer and uses the tile window for storing
* the partial block tile.
*/
template <typename OAccTile>
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
@@ -420,7 +442,10 @@ struct StreamKKernel
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
const auto& partial_tensor_view = make_naive_tensor_view<
address_space_enum::global,
memory_operation_enum::set,
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
@@ -431,8 +456,11 @@ struct StreamKKernel
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0});
store_tile(partial_tile_window, c_block_tile);
// Wait for all vector stores for this wavefront to complete
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
// Wait for all wavefronts in this workgroup to arrive here before continuing
__builtin_amdgcn_s_barrier();
}
/**
@@ -483,7 +511,8 @@ struct StreamKKernel
{
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
}
else
else if(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Reduction ||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::TreeReduction)
{
const auto c_macro_tile_idx =
kargs.tile_partitioner.get_output_tile_index(tile_idx);
@@ -528,46 +557,107 @@ struct StreamKKernel
auto tile_started = iter_start == tile_iter_start;
auto tile_ended = iter_end >= tile_iter_end;
if(!tile_started)
if constexpr(TilePartitioner::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
StorePartial(kargs, cta_idx, c_block_tile);
// Ensure device-wide visibility of partial results stored in global memory
// before signaling completion. __threadfence() guarantees that all global
// memory writes by this thread are visible to other threads on the device.
__threadfence(); // send signal when the store is done
SignalStorePartialDone(kargs, cta_idx);
if(!tile_started)
{
StorePartial(kargs, cta_idx, c_block_tile);
SignalStorePartialDone(kargs, cta_idx);
}
else
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
{
const index_t iter_per_tile =
kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta =
kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = cta_idx + 1;
while(accum_iters < iter_per_tile)
{
WaitStorePartialDone(kargs, next_cta);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
}
else
else // Tree Reduction
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
index_t tile_local_cta_idx =
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
for(index_t stride = 1;; stride <<= 1)
{
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = cta_idx + 1;
const index_t partner_cta_idx = cta_idx + stride;
const index_t partner_start_iter =
kargs.tile_partitioner.get_start_iter(partner_cta_idx);
bool partner_in_tile = partner_start_iter < tile_iter_end;
while(accum_iters < iter_per_tile)
// If the partner of the workgroup who started the tile is not in this tile,
// then the work for this tile is done and results can be stored in the C
// tensor.
if(tile_started && !partner_in_tile)
{
WaitStorePartialDone(kargs, next_cta);
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
break;
}
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
// It's this workgroup's turn to read from partials.
if(tile_local_cta_idx % (stride << 1) == 0)
{
// If this workgroup's partner is in the tile then it can read from
// partials and accumulate results.
if(partner_in_tile)
{
WaitStorePartialDone(kargs, partner_cta_idx);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs,
partner_cta_idx,
c_block_tile.get_tile_distribution()));
}
}
// Otherwise, it's this workgroup's turn to write to partials. All
// workgroups, except the workgroup who starts the tile, will write to
// partials.
else
{
StorePartial(kargs, cta_idx, accum_block_tile);
SignalStorePartialDone(kargs, cta_idx);
// Once the workgroup writes to partials, it has no more work to do for
// this tile.
break;
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
}
else
{
static_assert(
"An implementation does not exist for the chosen reduction strategy.");
}
// Prepare for next Stream-K loop iteration.
iter_start = tile_iter_end;
@@ -631,6 +721,7 @@ struct StreamKKernel
tile_idx += kargs.tile_partitioner.get_grid())
{
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
block_sync_lds();
}
// Stream-K section
@@ -639,10 +730,10 @@ struct StreamKKernel
private:
/**
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
* the starting macro tile index in the K dimension for the workgroup.
* @return A tuple containing the offsets into the A and B tensors accounting for the layouts
* of A and B.
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset
* is the starting macro tile index in the K dimension for the workgroup.
* @return A tuple containing the offsets into the A and B tensors accounting for the
* layouts of A and B.
* @note The default case is that A is assumed to be row major and B is assumed to be column
* major.
*/
@@ -679,15 +770,16 @@ struct StreamKKernel
{
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
ck_tile::hip_check_error(hipGetDevice(&dev));
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
int num_cu = dev_prop.multiProcessorCount;
return num_cu;
}
/**
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the
* kernel
* @return The occupancy
* @note This function queries the maximum occupancy of the kernel using
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
@@ -700,7 +792,7 @@ struct StreamKKernel
constexpr int min_block_per_cu = 1;
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
hip_check_error(
ck_tile::hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
return max(occupancy, 1);

View File

@@ -46,6 +46,16 @@ struct StreamKTilePartitionerBase
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept;
public:
/**
* @brief Calculates the start iteration for the given the cta_idx.
* @param cta_idx The current Stream-K workgroup's index.
* @return index_t The start iteration.
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
* like `blockIdx.x` minus number of DP workgroups.
*/
CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept;
/**
* @brief Calculates the start and end iteration given the cta_idx.
*
@@ -107,7 +117,17 @@ struct StreamKTilePartitionerBase
get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
/**
* @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
* @brief Calculates the workgroup's local CTA idx within the given tile.
*
* @param tile_iter_start The starting tile iteration.
* @param cta_idx The Stream-K workgroup index.
* @return index_t The tile local workgroup index in the tile.
*/
CK_TILE_DEVICE index_t get_tile_local_cta_index(index_t tile_iter_start,
index_t cta_idx) const noexcept;
/**
* @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.
*
* @param tile_idx The 1D tile index in the C tensor for the workgroup.
* @return index_t The corresponding 2D tile index in the C tensor for the workgroup.

View File

@@ -61,13 +61,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags
return sizeof(index_t) * sk_ctas_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_start_iter(
index_t cta_idx) const noexcept
{
// Compute the number of extra iterations done before this CTA. If the cta_idx is less than
// extra_iters, the number of extra iterations before the CTA is exactly the cta_idx. Otherwise,
// it is extra_iters.
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
return total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE void
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iter_boundaries(
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
{
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
iter = get_start_iter(cta_idx);
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
}
@@ -104,6 +115,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_local_cta_index(
index_t tile_iter_start, index_t cta_idx) const noexcept
{
tile_iter_start = tile_iter_start - (dp_tiles_ * iters_per_tile_);
// Compute how many WGs fit before this tile starts assuming each WG does an
// extra_iter
const index_t num_extra_iter_ctas = tile_iter_start / (iters_per_sk_cta_ + 1);
// Compute how many WGs fit before this tile starts excluding extra iters
const index_t num_non_extra_iter_ctas = (tile_iter_start - extra_iters_) / iters_per_sk_cta_;
// Compute the CTA idx for the CTA that starts this tile
const index_t coop_group_start =
num_extra_iter_ctas < extra_iters_ ? num_extra_iter_ctas : num_non_extra_iter_ctas;
return cta_idx - coop_group_start;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE auto
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_output_tile_index(
@@ -121,7 +150,8 @@ CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
index_t acc_element_bytes) const noexcept
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction ||
ReductionStrategy == StreamKReductionStrategy::TreeReduction)
{
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();

View File

@@ -280,7 +280,7 @@ struct UniversalGemmKernel
using Kernel = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto kernel = kentry<1, Kernel, KernelArgs>;
int occupancy;
hip_check_error(
ck_tile::hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
const int grid_size = get_available_compute_units(s) * occupancy;
@@ -1084,7 +1084,7 @@ struct UniversalGemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
@@ -1169,7 +1169,7 @@ struct UniversalGemmKernel
// Run the GEMM
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&

View File

@@ -9,11 +9,35 @@
namespace ck_tile {
template <typename Problem>
struct BaseGemmPipelineAGmemBGmemCRegV1
{
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = false;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
{
return TailNumber::Empty;
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
{
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAGmemBGmemCRegV1
struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
{
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
@@ -48,14 +72,14 @@ struct GemmPipelineAGmemBGmemCRegV1
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Problem::VectorSizeA;
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Problem::VectorSizeB;
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }

Some files were not shown because too many files have changed in this diff Show More