mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Merge branch 'develop' into lwpck-4181
This commit is contained in:
@@ -233,7 +233,20 @@ int run_contraction_bilinear_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result,
|
||||
e_ms_ns_host_result,
|
||||
"Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -216,7 +216,20 @@ int run_contraction_scale_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result,
|
||||
e_ms_ns_host_result,
|
||||
"Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -65,7 +65,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
# there is no corresponding instance for parameters).
|
||||
if(BUILD_TESTING)
|
||||
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
|
||||
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
|
||||
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*)
|
||||
endif()
|
||||
|
||||
# generate a list of kernels, but not actually emit files at config sta
|
||||
|
||||
@@ -76,7 +76,8 @@ using fmha_traits = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_dropout},
|
||||
{F_qscale},
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
{F_skip},
|
||||
{F_sink}>;
|
||||
|
||||
using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
@@ -113,7 +114,7 @@ using fmha_kernel = {F_kernel}<fmha_pipeline, fmha_epilogue>;
|
||||
|
||||
|
||||
using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>;
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
@@ -229,9 +230,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hd
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>;
|
||||
return fmha_fwd_<trait_, {F_arch.tag}>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -278,13 +279,14 @@ class FmhaFwdApiTrait:
|
||||
dvpad: str
|
||||
skip: str
|
||||
tr_load: str
|
||||
sink: str
|
||||
constraint: CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -384,6 +386,7 @@ class FmhaFwdPipeline:
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_skip: str # true/false
|
||||
F_trload: str # true/false
|
||||
F_sink: str # true/false
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
@@ -454,6 +457,10 @@ class FmhaFwdPipeline:
|
||||
n += "_trload"
|
||||
else:
|
||||
n += "_ntrload"
|
||||
if self.F_sink == "t":
|
||||
n += "_sink"
|
||||
else:
|
||||
n += "_nsink"
|
||||
|
||||
return n
|
||||
|
||||
@@ -543,6 +550,7 @@ class FmhaFwdApiPool:
|
||||
F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
|
||||
F_qscale=QSCALE_MAP[trait.qscale],
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune(max_bm0),
|
||||
F_skcheck=trait.skcheck,
|
||||
@@ -683,6 +691,7 @@ class FmhaFwdKernel:
|
||||
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag),
|
||||
F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag),
|
||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -725,6 +734,7 @@ class FmhaFwdKernel:
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip,
|
||||
tr_load=self.F_pipeline.F_trload,
|
||||
sink=self.F_pipeline.F_sink,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
)
|
||||
|
||||
@@ -957,52 +967,55 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
pipelines = []
|
||||
if dtype in cls._DT_FP32:
|
||||
qscale = "no"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
elif dtype in cls._DT_FP16_BF16:
|
||||
qscale = "no"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
|
||||
elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
for logits, qscale, mask, bias, sink in itertools.product(
|
||||
["f"],
|
||||
["no", "pertensor"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
["f", "t"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
elif dtype in ["fp8", "fp8fp16", "bf8"]:
|
||||
# TODO
|
||||
pass
|
||||
@@ -1033,13 +1046,14 @@ class KernelComponentFactoryGfx950(
|
||||
)
|
||||
if dtype in cls._DT_FP16_BF16:
|
||||
qscale = "no"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
if (
|
||||
(hdim, hdim_v) in [(64, 64), (128, 128)]
|
||||
@@ -1048,15 +1062,15 @@ class KernelComponentFactoryGfx950(
|
||||
and dropout == "f"
|
||||
and skip == "f"
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
|
||||
|
||||
# qr_async_trload_v3 only supports hdim=hdim_v=128 for now
|
||||
if (hdim, hdim_v) == (128, 128):
|
||||
# qr_async_trload_v3 only supports (generic) causal mask
|
||||
for mask in ["no", "causal"]:
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
|
||||
F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip
|
||||
F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
|
||||
|
||||
return pipelines
|
||||
|
||||
@@ -1105,23 +1119,24 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
|
||||
pipelines = []
|
||||
if dtype in cls._DT_FP16_BF16:
|
||||
qscale = "no"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
|
||||
return pipelines
|
||||
|
||||
|
||||
|
||||
@@ -73,7 +73,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
|
||||
{F_pagedkv},
|
||||
kHasUnevenSplits,
|
||||
kMergeNumHeadGroupsSeqLenQ,
|
||||
{F_occupancy}>;
|
||||
{F_occupancy},
|
||||
{F_sink}>;
|
||||
|
||||
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
@@ -117,7 +118,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
}} // anonymous namespace
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad},
|
||||
{F_dvpad}>;
|
||||
|
||||
#pragma clang diagnostic push
|
||||
@@ -279,8 +280,8 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
// get combine kernel tile sizes
|
||||
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
|
||||
@@ -332,6 +333,7 @@ class FmhaFwdSplitKVApiTrait:
|
||||
dpad: str
|
||||
dvpad: str
|
||||
pagedkv: str
|
||||
sink: str # sink or not
|
||||
bn1comb: int # tile size along v head_dim of combine kernel
|
||||
|
||||
@property
|
||||
@@ -339,7 +341,7 @@ class FmhaFwdSplitKVApiTrait:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-"
|
||||
+ f"{self.dvpad}-{self.pagedkv}"
|
||||
+ f"{self.dvpad}-{self.pagedkv}-{self.sink}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -425,6 +427,7 @@ class FmhaFwdSplitKVPipeline:
|
||||
F_lse: str #
|
||||
F_squant: str #
|
||||
F_pagedkv: str # t/f
|
||||
F_sink: str # t/f
|
||||
F_mask: str # value from MASK_MAP
|
||||
|
||||
@property
|
||||
@@ -485,6 +488,10 @@ class FmhaFwdSplitKVPipeline:
|
||||
n += "_pagedkv"
|
||||
else:
|
||||
n += "_npagedkv"
|
||||
if self.F_sink == "t":
|
||||
n += "_sink"
|
||||
else:
|
||||
n += "_nsink"
|
||||
return n
|
||||
|
||||
|
||||
@@ -567,6 +574,7 @@ class FmhaFwdSplitKVApiPool:
|
||||
F_lse=BOOL_MAP[trait.lse],
|
||||
F_squant=BOOL_MAP[trait.squant],
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
F_scheck=trait.scheck,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
@@ -667,6 +675,7 @@ class FmhaFwdSplitKVKernel:
|
||||
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
@@ -740,19 +749,23 @@ class KernelComponentFactoryBase:
|
||||
squant = "t" if dtype == "fp8" else "f"
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask, bias, pagedkv in itertools.product(
|
||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
|
||||
for logits, mask, bias, pagedkv, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
for logits, mask, bias in itertools.product(
|
||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
):
|
||||
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
|
||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||
# TODO
|
||||
None
|
||||
@@ -908,6 +921,7 @@ def get_fwd_splitkv_blobs(
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
@@ -917,6 +931,7 @@ def get_fwd_splitkv_blobs(
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
@@ -1075,6 +1090,7 @@ def write_blobs(
|
||||
lse=kernel.F_pipeline.F_lse,
|
||||
squant=kernel.F_pipeline.F_squant,
|
||||
pagedkv=kernel.F_pipeline.F_pagedkv,
|
||||
sink=kernel.F_pipeline.F_sink,
|
||||
spad=kernel.F_pipeline.F_spad,
|
||||
skpad=kernel.F_pipeline.F_skpad,
|
||||
dpad=kernel.F_pipeline.F_dpad,
|
||||
|
||||
@@ -65,7 +65,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad},
|
||||
{F_pagedkv}, //pagedkv
|
||||
{F_squant},
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
{F_skip},
|
||||
{F_sink}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
@@ -100,7 +101,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>;
|
||||
|
||||
template<>
|
||||
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
|
||||
@@ -129,9 +130,9 @@ float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, con
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>;
|
||||
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -163,12 +164,13 @@ class FmhaFwdApiTrait:
|
||||
dpad: str
|
||||
dvpad: str
|
||||
skip: str
|
||||
sink: str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -256,6 +258,7 @@ class FmhaFwdPipeline:
|
||||
F_squant: str #
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_skip: str # true/false
|
||||
F_sink: str # true/false
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -320,6 +323,10 @@ class FmhaFwdPipeline:
|
||||
n += "_pagedkv"
|
||||
else:
|
||||
n += "_npagedkv"
|
||||
if self.F_sink == "t":
|
||||
n += "_sink"
|
||||
else:
|
||||
n += "_nsink"
|
||||
|
||||
return n
|
||||
|
||||
@@ -363,6 +370,7 @@ class FmhaFwdApiPool:
|
||||
F_lse=BOOL_MAP[trait.lse],
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_skip=BOOL_MAP[trait.skip],
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
F_squant=BOOL_MAP[trait.squant],
|
||||
F_scheck=trait.scheck,
|
||||
F_skcheck=trait.skcheck,
|
||||
@@ -480,6 +488,7 @@ class FmhaFwdKernel:
|
||||
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
@@ -526,6 +535,7 @@ class FmhaFwdKernel:
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip,
|
||||
sink=self.F_pipeline.F_sink,
|
||||
)
|
||||
|
||||
|
||||
@@ -539,22 +549,23 @@ class KernelComponentFactoryBase:
|
||||
squant = "t" if dtype == "fp8" else "f"
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask, bias, pagedkv, skip in itertools.product(
|
||||
for logits, mask, bias, pagedkv, skip, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t"],
|
||||
["f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(
|
||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
|
||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||
pass # TODO
|
||||
else:
|
||||
@@ -678,6 +689,7 @@ def get_fwd_blobs(
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
@@ -687,6 +699,7 @@ def get_fwd_blobs(
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
|
||||
@@ -266,6 +266,7 @@ struct fmha_fwd_args
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t sink_size;
|
||||
ck_tile::index_t mask_type;
|
||||
ck_tile::index_t min_seqlen_q;
|
||||
|
||||
@@ -352,6 +353,7 @@ struct fmha_fwd_pagedkv_args
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t sink_size;
|
||||
ck_tile::index_t mask_type;
|
||||
ck_tile::index_t min_seqlen_q;
|
||||
};
|
||||
@@ -442,6 +444,7 @@ struct fmha_fwd_splitkv_args
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t sink_size;
|
||||
ck_tile::index_t mask_type;
|
||||
};
|
||||
|
||||
@@ -561,6 +564,7 @@ struct fmha_batch_prefill_args
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t sink_size;
|
||||
ck_tile::index_t mask_type;
|
||||
|
||||
float p_drop;
|
||||
@@ -613,6 +617,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q,
|
||||
args.p_drop,
|
||||
@@ -663,6 +668,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
@@ -824,6 +830,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.batch_stride_v,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q);
|
||||
}
|
||||
@@ -869,6 +876,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
}
|
||||
}();
|
||||
@@ -935,6 +943,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.split_stride_o_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
}
|
||||
else
|
||||
@@ -982,6 +991,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.split_stride_o_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
}
|
||||
}();
|
||||
@@ -1142,6 +1152,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.batch_stride_v,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
@@ -1194,6 +1205,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
@@ -1228,7 +1240,8 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kUseTrLoad_,
|
||||
bool kSkipMinSeqlenQ_ = false>
|
||||
bool kSkipMinSeqlenQ_ = false,
|
||||
bool kHasSink_ = false>
|
||||
struct fmha_fwd_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -1254,6 +1267,7 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
@@ -1280,7 +1294,8 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kSkipMinSeqlenQ_ = false>
|
||||
bool kSkipMinSeqlenQ_ = false,
|
||||
bool kHasSink_ = false>
|
||||
struct fmha_fwd_pagedkv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -1305,6 +1320,7 @@ struct fmha_fwd_pagedkv_traits_
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
@@ -1327,6 +1343,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kStoreLse_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kIsPagedKV_,
|
||||
bool kHasSink_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
@@ -1354,6 +1371,7 @@ struct fmha_fwd_splitkv_traits_
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
@@ -1440,6 +1458,7 @@ struct fmha_fwd_traits
|
||||
bool has_dropout;
|
||||
quant_scale_enum qscale_type;
|
||||
bool skip_min_seqlen_q = false;
|
||||
bool has_sink = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
@@ -1458,6 +1477,7 @@ struct fmha_fwd_pagedkv_traits
|
||||
bool use_pagedkv = true;
|
||||
bool do_fp8_static_quant = false;
|
||||
bool skip_min_seqlen_q = false;
|
||||
bool has_sink = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
@@ -1477,6 +1497,7 @@ struct fmha_fwd_splitkv_traits
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_lse;
|
||||
bool do_fp8_static_quant;
|
||||
bool has_sink = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
|
||||
|
||||
@@ -879,6 +879,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
traits.has_logits_soft_cap = 0.f < logits_soft_cap;
|
||||
traits.mask_type = mask.type;
|
||||
traits.bias_type = bias.type;
|
||||
traits.has_sink = mask.sink > 0 ? true : false;
|
||||
traits.has_lse = lse;
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
|
||||
@@ -1042,6 +1043,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
args.window_size_left = mask.left;
|
||||
args.window_size_right = mask.right;
|
||||
args.sink_size = mask.sink;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
||||
@@ -1645,7 +1647,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::reference_batched_masking<SaccDataType>(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
|
||||
mask.left, mask.right, mask.sink, real_seqlen_q, real_seqlen_k));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1657,6 +1659,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
mask.sink,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
@@ -1666,6 +1669,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
mask.sink,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
|
||||
@@ -25,6 +25,7 @@ struct mask_info
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t y, x;
|
||||
ck_tile::index_t left, right; // FA style SWA left/right
|
||||
ck_tile::index_t sink;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
@@ -58,13 +59,14 @@ struct mask_info
|
||||
ck_tile::index_t window_size = std::stoi(v);
|
||||
ck_tile::index_t left_size = -1;
|
||||
ck_tile::index_t right_size = 0;
|
||||
ck_tile::index_t sink_size = 0;
|
||||
if(window_size > 0)
|
||||
{
|
||||
left_size = window_size / 2;
|
||||
right_size = window_size - 1 - left_size;
|
||||
}
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, y_total, x_total, t == "xt");
|
||||
left_size, right_size, sink_size, y_total, x_total, t == "xt");
|
||||
|
||||
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
@@ -79,27 +81,54 @@ struct mask_info
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
ck_tile::index_t v0 = std::stoi(v.substr(0, found_1));
|
||||
ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1));
|
||||
tmp.type = mask_enum::window_generic;
|
||||
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
|
||||
auto found_2 = v.find(',', found_1 + 1);
|
||||
ck_tile::index_t v1 = 0;
|
||||
ck_tile::index_t sink = 0;
|
||||
// ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
// TODO: some validation
|
||||
if(t == "t")
|
||||
{
|
||||
if(found_2 != std::string::npos)
|
||||
{
|
||||
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
|
||||
sink = atoi(v.substr(found_2 + 1).c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
sink = 0;
|
||||
}
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, y_total, x_total, true);
|
||||
v0, v1, sink, y_total, x_total, true);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
tmp.sink = sink;
|
||||
}
|
||||
else if(t == "b")
|
||||
{
|
||||
if(found_2 != std::string::npos)
|
||||
{
|
||||
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
|
||||
sink = atoi(v.substr(found_2 + 1).c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
sink = 0;
|
||||
}
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, y_total, x_total, false);
|
||||
v0, v1, sink, y_total, x_total, false);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
tmp.sink = sink;
|
||||
}
|
||||
else if(t == "g")
|
||||
{
|
||||
@@ -108,6 +137,7 @@ struct mask_info
|
||||
tmp.x = v1;
|
||||
tmp.left = v0; // TODO: don't use this?
|
||||
tmp.right = v1;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -126,6 +156,7 @@ struct mask_info
|
||||
tmp.x = 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
else if(str == "2" || str == "b")
|
||||
{
|
||||
@@ -134,6 +165,7 @@ struct mask_info
|
||||
tmp.x = seqlen_k - seqlen_q + 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
77
example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh
Normal file
77
example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh
Normal file
@@ -0,0 +1,77 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# TODO: run this script from CK root or build directory
|
||||
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
|
||||
KNAME=1
|
||||
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
# mode=0
|
||||
# export HIP_VISIBLE_DEVICES=4
|
||||
|
||||
TEST_SPLITKV=0
|
||||
TEST_APPENDKV=0
|
||||
# options:
|
||||
# -s: run splitkv tests
|
||||
# -a: run appendkv tests
|
||||
while getopts ":sa" opt; do
|
||||
case "${opt}" in
|
||||
s)
|
||||
TEST_SPLITKV=1
|
||||
;;
|
||||
a)
|
||||
TEST_APPENDKV=1
|
||||
;;
|
||||
*)
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
run_fp16_bf16_tests() {
|
||||
local NUM_SPLITS="1"
|
||||
local PAGE_BLOCK_SIZE="0"
|
||||
local CACHE_BATCH_IDX="0"
|
||||
|
||||
if [ $TEST_SPLITKV -eq 1 ] ; then
|
||||
NUM_SPLITS="$NUM_SPLITS 2 3"
|
||||
PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
|
||||
CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
|
||||
fi
|
||||
|
||||
for prec in "fp16"; do
|
||||
for mode in 1 0 ; do
|
||||
for perm in 0 1 ; do
|
||||
for vlayout in "r" "c" ; do
|
||||
for batch in 1 4; do
|
||||
for head in 1; do
|
||||
for h_k in 1; do
|
||||
for q_seq in 128 512 ; do
|
||||
for kv_seq in 128 1024; do
|
||||
for hdim in 32 64 128 256; do #256
|
||||
for lse in 0 1 ; do
|
||||
for bias in "e" ; do
|
||||
for p_drop in 0.0 0.2; do # 0.0
|
||||
for mask in "t:2,0,4" "b:1,0,2"; do
|
||||
for num_splits in $NUM_SPLITS ; do
|
||||
for page_block_size in $PAGE_BLOCK_SIZE ; do
|
||||
for cache_batch_idx in $CACHE_BATCH_IDX ; do
|
||||
|
||||
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=$batch -h=$head -h_k=$h_k -d=16 -d_v=$hdim -s=$q_seq -s_k=$kv_seq -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS -mask=$mask
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done ; done ; done
|
||||
done ; done
|
||||
}
|
||||
|
||||
|
||||
set -x
|
||||
|
||||
run_fp16_bf16_tests
|
||||
|
||||
set +x
|
||||
@@ -39,6 +39,7 @@ function print_log_header(){
|
||||
#run verification tests
|
||||
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
|
||||
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
|
||||
time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
|
||||
|
||||
#run performance benchmarks
|
||||
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
|
||||
|
||||
86
example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
Executable file
86
example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
Executable file
@@ -0,0 +1,86 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# TODO: run this script from CK root or build directory
|
||||
#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd"
|
||||
set -euo pipefail
|
||||
|
||||
# Absolute directory containing this script. The inner command substitution is
# quoted so a checkout path with whitespace does not word-split and make `cd`
# fail (which would abort the script under `set -e`).
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||
EXE_NAME=tile_example_fmha_fwd
|
||||
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
|
||||
KNAME=1
|
||||
# Default to empty when the caller did not export GPU_arch: with `set -u`
# active, a bare "$GPU_arch" expansion would abort with "unbound variable"
# before the rocminfo-based detection below could run.
GPU_arch=${GPU_arch:-}
|
||||
if [ -z "$GPU_arch" ] ; then
|
||||
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
|
||||
fi
|
||||
set -x
|
||||
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2
|
||||
|
||||
# window_size[2,0], sink_size = 2
|
||||
|
||||
# x=1/y=3
|
||||
# 1 * * * * * * * 1 * * * * * * *
|
||||
# 1 1 * * * * * * 1 1 * * * * * *
|
||||
# 1 1 1 * * * * * ----> 1 1 1 * * * * *
|
||||
# * 1 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * * 1 1 1 * * * 1 1 1 1 1 * * *
|
||||
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
|
||||
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
|
||||
# * * * * * 1 1 1 1 1 * * * 1 1 1
|
||||
# l=2/r=0(tl) l=2/r=0/s=2(tl)
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2
|
||||
|
||||
# x=4/y=1
|
||||
# 1 1 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
|
||||
# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * *
|
||||
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
|
||||
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
|
||||
# l=0/r=3(tl) l=0/r=3/s=2(tl)
|
||||
# l=3/r=0(br) l=3/r=0/s=2(br)
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2
|
||||
|
||||
# x=4/y=-1
|
||||
# * * 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * * * 1 1 * * * 1 1 * 1 1 * * *
|
||||
# * * * * 1 1 * * ----> 1 1 * * 1 1 * *
|
||||
# * * * * * 1 1 * 1 1 * * * 1 1 *
|
||||
# * * * * * * 1 1 1 1 * * * * 1 1
|
||||
# l=1/r=0(br) l=1/r=0/s=2(br)
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2
|
||||
|
||||
# x=-1/y=5
|
||||
|
||||
# * * * * * * * * * * * *
|
||||
# * * * * * * * * * * * *
|
||||
# 1 * * * * * 1 * * * * *
|
||||
# 1 1 * * * * 1 1 * * * *
|
||||
# 1 1 1 * * * ----> 1 1 1 * * *
|
||||
# * 1 1 1 * * 1 1 1 1 * *
|
||||
# * * 1 1 1 * 1 1 1 1 1 *
|
||||
# * * * 1 1 1 1 1 * 1 1 1
|
||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2
|
||||
# x=-1/y=8
|
||||
# * * * * * * * * * *
|
||||
# * * * * * * * * * *
|
||||
# 1 * * * * ----> 1 * * * *
|
||||
# 1 1 * * * 1 1 * * *
|
||||
# 1 1 1 * * 1 1 1 * *
|
||||
# 1 1 1 1 * 1 1 1 1 *
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# l=-1/r=1(br) l=-1/r=1/s=2(br)
|
||||
|
||||
@@ -12,40 +12,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
@@ -122,7 +88,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
@@ -141,7 +108,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
@@ -160,7 +128,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
@@ -204,7 +173,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
@@ -223,7 +193,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
@@ -242,7 +213,8 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
|
||||
@@ -282,7 +254,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
@@ -306,7 +279,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
@@ -459,7 +433,7 @@ struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args()
|
||||
inline auto create_args()
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
|
||||
@@ -197,8 +197,8 @@ bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>
|
||||
parse_gemm_size(ck_tile::ArgParser& arg_parser)
|
||||
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> inline parse_gemm_size(
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
|
||||
@@ -63,25 +63,30 @@ struct UniversalInvoker
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -11,40 +11,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
@@ -111,7 +77,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
@@ -134,7 +101,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
@@ -157,7 +125,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
@@ -178,7 +147,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
@@ -203,7 +173,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
|
||||
@@ -11,24 +11,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
|
||||
@@ -10,40 +10,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
@@ -100,7 +66,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType, bool Persistent>
|
||||
@@ -117,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
|
||||
@@ -24,39 +24,6 @@ inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
|
||||
return combined_hash;
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
@@ -91,7 +58,7 @@ struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
@@ -124,7 +91,8 @@ struct GemmConfigQuantDecode : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -140,7 +108,8 @@ struct GemmConfigRowColQuant : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -157,7 +126,7 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
@@ -176,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
@@ -206,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
@@ -236,7 +205,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
@@ -391,20 +391,12 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if(K % QuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
ck_tile::index_t AQK, BQK, BQN = 0;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
|
||||
BQK = 0; // No B quantization
|
||||
AQK = ck_tile::integer_divide_ceil(
|
||||
K, QuantGroupSize::kK); // Group quantization: AQK = K / GroupSize
|
||||
BQK = 0; // No B quantization
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
|
||||
@@ -7,46 +7,46 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
struct GemmConfigBase
|
||||
struct GemmConfigurationBase
|
||||
{
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
static constexpr bool kPadK = true;
|
||||
static constexpr bool PAD_M = true;
|
||||
static constexpr bool PAD_N = true;
|
||||
static constexpr bool PAD_K = true;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
static constexpr bool PERMUTE_A = false;
|
||||
static constexpr bool PERMUTE_B = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr bool TRANSPOSE_C = false;
|
||||
static constexpr bool USE_STRUCTURED_SPARSITY = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr int BLOCK_PER_CU = 1;
|
||||
static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NUM_WAVE_GROUPS = 1;
|
||||
static constexpr bool PRESHUFFLE = false;
|
||||
static constexpr bool DOUBLE_SMEM_BUFFER = false;
|
||||
};
|
||||
|
||||
template <typename PrecType, bool Persistent_>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
template <typename PrecisionType, bool IsPersistent>
|
||||
struct GemmConfigurationMemoryInterwave : public GemmConfigurationBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 16;
|
||||
static constexpr ck_tile::index_t M_TILE = 256;
|
||||
static constexpr ck_tile::index_t N_TILE = 256;
|
||||
static constexpr ck_tile::index_t K_TILE = 16;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
static constexpr ck_tile::index_t M_WARP = 2;
|
||||
static constexpr ck_tile::index_t N_WARP = 2;
|
||||
static constexpr ck_tile::index_t K_WARP = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
static constexpr ck_tile::index_t M_WARP_TILE = 32;
|
||||
static constexpr ck_tile::index_t N_WARP_TILE = 32;
|
||||
static constexpr ck_tile::index_t K_WARP_TILE = sizeof(PrecisionType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool Persistent = Persistent_;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool PERSISTENT = IsPersistent;
|
||||
static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
|
||||
struct StreamKGemmTypeConfig
|
||||
struct StreamKGemmTypeConfiguration
|
||||
{
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
@@ -54,7 +54,7 @@ struct StreamKGemmTypeConfig
|
||||
using CDataType = CDataType_;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
auto createArgs(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "512", "m dimension")
|
||||
|
||||
@@ -12,31 +12,35 @@ static constexpr inline auto is_row_major(Layout)
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
auto calculateRtolAtol(const ck_tile::index_t k_dim,
|
||||
const ck_tile::index_t k_batch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto relative_tolerance =
|
||||
ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(k_dim, k_batch));
|
||||
const auto absolute_tolerance =
|
||||
ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / k_batch, ck_tile::integer_divide_ceil(k_dim, k_batch));
|
||||
// Calculate error due to multiple WGs working in the same C macro tile
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
const auto relative_tolerance_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(k_batch);
|
||||
const auto absolute_tolerance_split_k =
|
||||
ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(max_accumulated_value,
|
||||
k_batch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
return ck_tile::make_tuple(std::max(relative_tolerance, relative_tolerance_split_k),
|
||||
std::max(absolute_tolerance, absolute_tolerance_split_k));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
template <typename GemmConfiguration,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename AccumulatorDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -45,102 +49,107 @@ template <typename GemmConfig,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s);
|
||||
const ck_tile::stream_config& stream_config);
|
||||
|
||||
template <typename GemmConfig,
|
||||
template <typename GemmConfiguration,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename AccumulatorDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
std::tuple<float, ck_tile::index_t>
|
||||
invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool flush_cache,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy)
|
||||
std::tuple<float, ck_tile::index_t> invokeGemm(ck_tile::DeviceMem& a_m_k_device_memory,
|
||||
ck_tile::DeviceMem& b_k_n_device_memory,
|
||||
ck_tile::DeviceMem& c_m_n_device_memory,
|
||||
ck_tile::index_t m_dim,
|
||||
ck_tile::index_t n_dim,
|
||||
ck_tile::index_t k_dim,
|
||||
ck_tile::index_t stride_a,
|
||||
ck_tile::index_t stride_b,
|
||||
ck_tile::index_t stride_c,
|
||||
int warmup_iterations,
|
||||
int repeat_iterations,
|
||||
bool flush_cache,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy)
|
||||
{
|
||||
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C};
|
||||
ck_tile::StreamKHostArgs args{a_m_k_device_memory.GetDeviceBuffer(),
|
||||
b_k_n_device_memory.GetDeviceBuffer(),
|
||||
c_m_n_device_memory.GetDeviceBuffer(),
|
||||
m_dim,
|
||||
n_dim,
|
||||
k_dim,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c};
|
||||
|
||||
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
|
||||
std::tuple<float, ck_tile::index_t> average_time_and_batch;
|
||||
|
||||
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
ave_time_and_batch = gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
|
||||
average_time_and_batch = gemm<GemmConfiguration,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args,
|
||||
ck_tile::stream_config{
|
||||
nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache});
|
||||
}
|
||||
else /*Reduction*/
|
||||
{
|
||||
ave_time_and_batch = gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy::Reduction>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
|
||||
average_time_and_batch = gemm<GemmConfiguration,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy::Reduction>(
|
||||
args,
|
||||
ck_tile::stream_config{
|
||||
nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache});
|
||||
}
|
||||
|
||||
return ave_time_and_batch;
|
||||
return average_time_and_batch;
|
||||
}
|
||||
|
||||
template <typename CDataType>
|
||||
bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
const ck_tile::HostTensor<CDataType>& c_m_n_ref,
|
||||
const ck_tile::tuple<double, double>& rtol_atol,
|
||||
const char* variant)
|
||||
bool doVerify(const ck_tile::HostTensor<CDataType>& c_m_n_device_result,
|
||||
const ck_tile::HostTensor<CDataType>& c_m_n_reference,
|
||||
const ck_tile::tuple<double, double>& relative_absolute_tolerances,
|
||||
const char* variant)
|
||||
{
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_ref,
|
||||
bool pass = ck_tile::check_err(c_m_n_device_result,
|
||||
c_m_n_reference,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
relative_absolute_tolerances.at(ck_tile::number<0>{}),
|
||||
relative_absolute_tolerances.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "Relative error threshold: "
|
||||
<< relative_absolute_tolerances.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: "
|
||||
<< relative_absolute_tolerances.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The " << variant << " verification result is:" << (pass ? "correct" : "fail")
|
||||
<< std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string& strategy)
|
||||
ck_tile::StreamKReductionStrategy getReductionStrategyValue(const std::string& strategy)
|
||||
{
|
||||
if(strategy == "atomic")
|
||||
{
|
||||
@@ -156,172 +165,169 @@ ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
template <typename GemmConfiguration,
|
||||
typename TypeConfiguration,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
int runGemmExampleWithLayouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
auto [result, arg_parser] = createArgs(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
static_assert(!GemmConfig::Preshuffle, "Not implemented");
|
||||
static_assert(!GemmConfig::UseStructuredSparsity, "Not implemented");
|
||||
static_assert(!GemmConfig::PermuteA, "Not implemented");
|
||||
static_assert(!GemmConfig::PermuteB, "Not implemented");
|
||||
static_assert(!GemmConfiguration::PRESHUFFLE, "Not implemented");
|
||||
static_assert(!GemmConfiguration::USE_STRUCTURED_SPARSITY, "Not implemented");
|
||||
static_assert(!GemmConfiguration::PERMUTE_A, "Not implemented");
|
||||
static_assert(!GemmConfiguration::PERMUTE_B, "Not implemented");
|
||||
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
using CDataType = typename TypeConfig::CDataType;
|
||||
using ADataType = typename TypeConfiguration::ADataType;
|
||||
using BDataType = typename TypeConfiguration::BDataType;
|
||||
using AccumulatorDataType = typename TypeConfiguration::AccDataType;
|
||||
using CDataType = typename TypeConfiguration::CDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
ck_tile::index_t m_dim = arg_parser.get_int("m");
|
||||
ck_tile::index_t n_dim = arg_parser.get_int("n");
|
||||
ck_tile::index_t k_dim = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
|
||||
int warmup_iterations = arg_parser.get_int("warmup");
|
||||
int repeat_iterations = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool flush_cache = arg_parser.get_bool("flush_cache");
|
||||
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy =
|
||||
get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));
|
||||
getReductionStrategyValue(arg_parser.get_str("reduction_strategy"));
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
stride_a = ck_tile::get_default_stride(m_dim, k_dim, stride_a, is_row_major(a_layout));
|
||||
stride_b = ck_tile::get_default_stride(k_dim, n_dim, stride_b, is_row_major(b_layout));
|
||||
stride_c = ck_tile::get_default_stride(m_dim, n_dim, stride_c, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::HostTensor<ADataType> a_m_k_host(
|
||||
ck_tile::host_tensor_descriptor(m_dim, k_dim, stride_a, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n_host(
|
||||
ck_tile::host_tensor_descriptor(k_dim, n_dim, stride_b, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_device_result(
|
||||
ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_host);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k_host);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n_host);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_host);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
b_k_n.SetZero();
|
||||
a_m_k_host.SetZero();
|
||||
b_k_n_host.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem a_m_k_device_memory(a_m_k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_device_memory(b_k_n_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_device_memory(c_m_n_device_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
a_m_k_device_memory.ToDevice(a_m_k_host.data());
|
||||
b_k_n_device_memory.ToDevice(b_k_n_host.data());
|
||||
c_m_n_device_memory.SetZero();
|
||||
c_m_n_device_result.SetZero();
|
||||
auto [average_time, num_wgs_per_tile] = invokeGemm<GemmConfiguration,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_device_memory,
|
||||
b_k_n_device_memory,
|
||||
c_m_n_device_memory,
|
||||
m_dim,
|
||||
n_dim,
|
||||
k_dim,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
warmup_iterations,
|
||||
repeat_iterations,
|
||||
flush_cache,
|
||||
reduction_strategy);
|
||||
|
||||
auto [ave_time, num_wgs_per_tile] = invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
flush_cache,
|
||||
reduction_strategy);
|
||||
c_m_n_device_memory.FromDevice(c_m_n_device_result.data());
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
|
||||
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
|
||||
std::size_t flop = std::size_t(2) * m_dim * n_dim * k_dim;
|
||||
std::size_t num_byte = sizeof(ADataType) * m_dim * k_dim + sizeof(BDataType) * n_dim * k_dim +
|
||||
sizeof(CDataType) * m_dim * n_dim;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / average_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / average_time;
|
||||
std::cout << "Run Gemm kernel with M=" << m_dim << " N=" << n_dim << " K=" << k_dim
|
||||
<< " StrideA=" << stride_a << " StrideB=" << stride_b << " StrideC=" << stride_c
|
||||
<< " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name
|
||||
<< " C_Layout=" << CLayout::name
|
||||
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
|
||||
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
|
||||
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
|
||||
<< " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
|
||||
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
|
||||
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << average_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
bool pass = false;
|
||||
|
||||
// Memory on host to store gpu reference result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_ref.SetZero();
|
||||
ck_tile::HostTensor<CDataType> c_m_n_reference(
|
||||
ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));
|
||||
c_m_n_reference.SetZero();
|
||||
|
||||
if(arg_parser.get_int("v") == 1) // Validate on the CPU
|
||||
{
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_ref);
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccumulatorDataType, CDataType>(
|
||||
a_m_k_host, b_k_n_host, c_m_n_reference);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, num_wgs_per_tile, max_accumulated_value);
|
||||
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
|
||||
*std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
|
||||
const auto relative_absolute_tolerances =
|
||||
calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
|
||||
k_dim, num_wgs_per_tile, max_accumulated_value);
|
||||
pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "CPU");
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2) // Validate on the GPU
|
||||
{
|
||||
// Memory on device to store gpu reference result
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
ck_tile::DeviceMem c_m_n_gpu_buffer_reference(
|
||||
c_m_n_reference.get_element_space_size_in_bytes());
|
||||
c_m_n_gpu_buffer_reference.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_device_memory.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_device_memory.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buffer_reference.GetDeviceBuffer());
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
|
||||
CLayout>(
|
||||
d_A, d_B, d_C, m_dim, n_dim, k_dim, stride_a, stride_b, stride_c);
|
||||
c_m_n_gpu_buffer_reference.FromDevice(c_m_n_reference.data());
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, num_wgs_per_tile, max_accumulated_value);
|
||||
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU");
|
||||
*std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
|
||||
const auto relative_absolute_tolerances =
|
||||
calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
|
||||
k_dim, num_wgs_per_tile, max_accumulated_value);
|
||||
pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "GPU");
|
||||
}
|
||||
|
||||
return pass;
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
#include "gemm_utils.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
template <typename GemmConfiguration,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename AccumulatorDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -17,43 +17,49 @@ template <typename GemmConfig,
|
||||
typename CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
const ck_tile::stream_config& stream_config)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<GemmConfiguration::M_TILE,
|
||||
GemmConfiguration::N_TILE,
|
||||
GemmConfiguration::K_TILE>,
|
||||
ck_tile::sequence<GemmConfiguration::M_WARP,
|
||||
GemmConfiguration::N_WARP,
|
||||
GemmConfiguration::K_WARP>,
|
||||
ck_tile::sequence<GemmConfiguration::M_WARP_TILE,
|
||||
GemmConfiguration::N_WARP_TILE,
|
||||
GemmConfiguration::K_WARP_TILE>,
|
||||
GemmConfiguration::PERMUTE_A,
|
||||
GemmConfiguration::PERMUTE_B>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
|
||||
using TilePartitioner = ck_tile::
|
||||
StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfiguration::PERSISTENT>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
GemmConfig::Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfiguration::PAD_M,
|
||||
GemmConfiguration::PAD_N,
|
||||
GemmConfiguration::PAD_K,
|
||||
GemmConfiguration::DOUBLE_SMEM_BUFFER,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfiguration::TRANSPOSE_C,
|
||||
GemmConfiguration::USE_STRUCTURED_SPARSITY,
|
||||
GemmConfiguration::PERSISTENT,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS,
|
||||
GemmConfiguration::PRESHUFFLE>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
|
||||
const auto runKernel = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccumulatorDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfiguration::SCHEDULER>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
@@ -61,39 +67,39 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfiguration::M_WARP,
|
||||
GemmConfiguration::N_WARP,
|
||||
GemmConfiguration::M_WARP_TILE,
|
||||
GemmConfiguration::N_WARP_TILE,
|
||||
GemmConfiguration::K_WARP_TILE,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation.value,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
GemmConfiguration::NUM_WAVE_GROUPS>>;
|
||||
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
|
||||
auto kernel_args = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
workspace_data.SetZero();
|
||||
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
|
||||
dim3 grids = Kernel::GridSize(kargs.tile_partitioner);
|
||||
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
@@ -109,7 +115,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
{
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
@@ -120,45 +126,47 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
|
||||
std::function<void()> preprocess = reset_data_buffers;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
float average_time =
|
||||
ck_tile::launch_kernel_time_mask(stream_config,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
|
||||
Kernel{}, grids, blocks, 0, kernel_args));
|
||||
|
||||
ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
return std::tuple{ave_time, num_wgs_per_tile};
|
||||
ck_tile::index_t num_wgs_per_tile =
|
||||
kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
return std::tuple{average_time, num_wgs_per_tile};
|
||||
};
|
||||
|
||||
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// Since we are doing stream K, in the case of
|
||||
// atomics, multiple workgroups may write to the same
|
||||
// output tile in the C tensor, so we must atomic add
|
||||
// the results (not set)
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// Since we are doing stream K, in the case of
|
||||
// atomics, multiple workgroups may write to the
|
||||
// same output tile in the C tensor, so we must
|
||||
// atomic add the results (not set)
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else // We are using ck_tile::StreamKReductionStrategy::Reduction
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// In this case, there is only ever 1 WG writing final
|
||||
// results to each macro tile in the C tensor, so we
|
||||
// can do a set.
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// In this case, there is only ever 1 WG writing
|
||||
// final results to each macro tile in the C
|
||||
// tensor, so we can do a set.
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename GemmConfig, typename TypeConfig>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
template <typename GemmConfiguration, typename TypeConfiguration>
|
||||
int runGemmExamplePrecisionType(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, TypeConfig>(
|
||||
return runGemmExampleWithLayouts<GemmConfiguration, TypeConfiguration>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -169,72 +177,74 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <template <typename PreType, bool Persistent_> typename GemmConfig>
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
template <template <typename PrecisionType, bool IsPersistent> typename GemmConfiguration>
|
||||
int runGemmExample(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
auto [result, arg_parser] = createArgs(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
auto persistent_dp = arg_parser.get_bool("persistent_dp");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
auto persistent_data_parallel = arg_parser.get_bool("persistent_dp");
|
||||
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf16_t>;
|
||||
if(persistent_dp)
|
||||
using TypeConfiguration = StreamKGemmTypeConfiguration<ck_tile::bf16_t>;
|
||||
if(persistent_data_parallel)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, true>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf16_t, true>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, false>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf16_t, false>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else if(data_type == "fp16")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::half_t>;
|
||||
if(persistent_dp)
|
||||
using TypeConfiguration = StreamKGemmTypeConfiguration<ck_tile::half_t>;
|
||||
if(persistent_data_parallel)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, true>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::half_t, true>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, false>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::half_t, false>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
|
||||
if(persistent_dp)
|
||||
using TypeConfiguration =
|
||||
StreamKGemmTypeConfiguration<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
|
||||
if(persistent_data_parallel)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, true>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::fp8_t, true>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, false>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::fp8_t, false>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
|
||||
if(persistent_dp)
|
||||
using TypeConfiguration =
|
||||
StreamKGemmTypeConfiguration<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
|
||||
if(persistent_data_parallel)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, true>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf8_t, true>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, false>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf8_t, false>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -247,5 +257,5 @@ int run_gemm_example(int argc, char* argv[])
|
||||
|
||||
// Program entry point of this GEMM example.
// Forwards argc/argv to the example runner template instantiated with the
// memory-interwave configuration template, and returns the logical negation
// of the runner's integer result as the process exit status: any non-zero
// result becomes 0, a zero result becomes 1.
// The two return statements below appear to be the pre-rename and post-rename
// forms of the same call as rendered by the diff view, not two statements of
// one function body.
// NOTE(review): the runner visibly returns -1 when argument parsing fails;
// the negation maps that to exit status 0 as well -- confirm this is intended.
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_gemm_example<GemmConfigMemoryInterwave>(argc, argv);
|
||||
return !runGemmExample<GemmConfigurationMemoryInterwave>(argc, argv);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user