mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Merge branch 'develop' into tianxing/unified-attention
This commit is contained in:
@@ -36,6 +36,19 @@ DTYPE_BITS = {
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
|
||||
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
|
||||
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
|
||||
KV_MEMORY_LAYOUT_ENUM_MAP = {
|
||||
"vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
|
||||
"linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
|
||||
}
|
||||
KV_LOOKUP_TABLE_ENUM_MAP = {
|
||||
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
|
||||
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
|
||||
}
|
||||
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
|
||||
}
|
||||
@@ -59,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
@@ -69,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_qscale},
|
||||
{F_occupancy}>;
|
||||
{F_occupancy},
|
||||
false,
|
||||
{F_page_size},
|
||||
{F_kv_memory_layout},
|
||||
{F_kv_lookup_table}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
@@ -92,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
false,
|
||||
{F_page_size},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
@@ -105,8 +123,8 @@ using fmha_epilogue_{F_idx} =
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -184,8 +202,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -230,12 +248,15 @@ class FmhaFwdApiTrait:
|
||||
dpad: str
|
||||
dvpad: str
|
||||
constraint: CppConstraint
|
||||
kv_memory_layout: str
|
||||
kv_lookup_table: str
|
||||
page_size: int = 1 # page block size
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -322,6 +343,8 @@ class FmhaFwdPipeline:
|
||||
F_dropout: str #
|
||||
F_qscale: str # no/pertensor
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_kv_memory_layout: str #
|
||||
F_kv_lookup_table: str #
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
@@ -382,6 +405,8 @@ class FmhaFwdPipeline:
|
||||
n += f"_{self.F_qscale}"
|
||||
else:
|
||||
n += "_nqscale"
|
||||
|
||||
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
|
||||
return n
|
||||
|
||||
|
||||
@@ -440,6 +465,13 @@ class FmhaFwdApiPool:
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
|
||||
trait.kv_memory_layout
|
||||
],
|
||||
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
|
||||
trait.kv_lookup_table
|
||||
],
|
||||
F_page_size=trait.page_size,
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
@@ -497,6 +529,7 @@ class FmhaFwdKernel:
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
F_page_size: int = 1 # page block size
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
@@ -534,17 +567,24 @@ class FmhaFwdKernel:
|
||||
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
|
||||
self.F_pipeline.F_kv_memory_layout
|
||||
],
|
||||
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
|
||||
self.F_pipeline.F_kv_lookup_table
|
||||
],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_page_size=self.F_page_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
@@ -578,6 +618,9 @@ class FmhaFwdKernel:
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
|
||||
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
|
||||
page_size=self.F_page_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -604,23 +647,42 @@ class KernelComponentFactory:
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
qscale = "no"
|
||||
for logits, mask, bias, lse, dropout in itertools.product(
|
||||
for (
|
||||
logits,
|
||||
mask,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
for (
|
||||
logits,
|
||||
qscale,
|
||||
mask,
|
||||
bias,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
["pertensor"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -672,69 +734,75 @@ def get_fwd_blobs(
|
||||
or pipeline.F_logits == "f"
|
||||
):
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
# Generate kernels for both page_size=16 and page_size=1024
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
F_page_size=page_size,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
@@ -315,7 +315,7 @@ class FmhaFwdApiTrait:
|
||||
assert False
|
||||
|
||||
def seqtune(self, max_bm0: int) -> str:
|
||||
if self.bm0 == max_bm0:
|
||||
if self.bm0 == max_bm0 or self.bm0 == 64:
|
||||
return "true/*fall back to largest tile*/"
|
||||
else:
|
||||
return f"a.seqlen_q <= {self.bm0}"
|
||||
@@ -847,6 +847,11 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
|
||||
(problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128)
|
||||
and kernel_ctx.tile.F_bm0 != 128
|
||||
)
|
||||
or (
|
||||
(problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128)
|
||||
and kernel_ctx.pipeline.tag != "qr_async"
|
||||
and kernel_ctx.tile.F_bk0 == 64
|
||||
)
|
||||
):
|
||||
# non qr_async_trload only support km0=128 tile size when hdim is not 128
|
||||
# non qr_async only support kn0=128 tile size when hdim is 128
|
||||
@@ -942,6 +947,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
|
||||
@@ -114,7 +114,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("kv_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.");
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("init_sink", "0", "value to init the output tensor sink value for validation");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -157,6 +158,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int init_sink_value = arg_parser.get_int("init_sink");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
@@ -203,6 +205,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
init_method,
|
||||
seed,
|
||||
do_validation,
|
||||
init_sink_value,
|
||||
stream_config,
|
||||
json);
|
||||
}
|
||||
|
||||
@@ -230,6 +230,7 @@ struct fmha_fwd_args
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -317,6 +318,7 @@ struct fmha_fwd_pagedkv_args
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -400,6 +402,7 @@ struct fmha_fwd_splitkv_args
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -476,6 +479,7 @@ struct fmha_fwd_appendkv_args
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
@@ -519,6 +523,7 @@ struct fmha_batch_prefill_args
|
||||
// 1) +
|
||||
// kargs.kv_last_page_lens[b]
|
||||
const void* seqstart_q_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -529,14 +534,25 @@ struct fmha_batch_prefill_args
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
// SGLang-style page table
|
||||
int32_t num_total_pages;
|
||||
void* kv_indptr;
|
||||
void* kv_page_indices;
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
void* kv_last_page_lens;
|
||||
ck_tile::index_t page_block_size;
|
||||
#endif
|
||||
// KV cache page table fields (kv_lookup_table selects interpretation):
|
||||
// - SGLANG_PAGE_TABLE_1D:
|
||||
// kv_indptr: prefix-sum [batch+1] into kv_page_indices
|
||||
// kv_page_indices: 1D list of physical page ids, length = num_total_pages
|
||||
// kv_last_page_lens: per-batch last page lengths [batch]
|
||||
// - VLLM_BLOCK_TABLE_2D:
|
||||
// kv_page_indices: block_table [batch, max_blocks_per_seq] (2D)
|
||||
// batch_stride_block_table: row stride for block_table
|
||||
// seqlen_k_ptr: per-batch seqlen_k [batch]
|
||||
int32_t num_total_pages; // total physical pages in KV cache (SGLang/vLLM)
|
||||
ck_tile::index_t page_block_size; // tokens per page (SGLang/vLLM)
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum
|
||||
kv_memory_layout; // KV memory layout (SGLang/vLLM)
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; // lookup table layout selector
|
||||
void* kv_indptr; // SGLang: prefix-sum; vLLM: unused
|
||||
void* kv_page_indices; // SGLang: 1D page list; vLLM: block_table 2D
|
||||
void* kv_last_page_lens; // SGLang: last page lengths; vLLM: unused
|
||||
void* seqlen_k_ptr; // vLLM: per-batch seqlen_k; SGLang: unused
|
||||
ck_tile::index_t batch_stride_block_table; // vLLM: row stride; SGLang: unused
|
||||
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
@@ -627,7 +643,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -677,7 +694,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -837,7 +855,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q);
|
||||
args.min_seqlen_q,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -882,7 +901,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -949,7 +969,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -997,7 +1018,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1113,6 +1135,22 @@ template <typename FmhaKernel>
|
||||
auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
using PageTableKargs = typename FmhaKernel::PageBlockTableKargs;
|
||||
const PageTableKargs page_table = [&]() {
|
||||
if constexpr(FmhaKernel::kKVLookupTable ==
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_indptr),
|
||||
reinterpret_cast<const int32_t*>(args.kv_page_indices),
|
||||
reinterpret_cast<const int32_t*>(args.kv_last_page_lens)};
|
||||
}
|
||||
else
|
||||
{
|
||||
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_page_indices),
|
||||
args.batch_stride_block_table,
|
||||
reinterpret_cast<const int32_t*>(args.seqlen_k_ptr)};
|
||||
}
|
||||
}();
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
@@ -1133,12 +1171,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
@@ -1164,7 +1198,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -1184,12 +1219,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
@@ -1220,7 +1251,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1281,6 +1313,65 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
bool kHasLogitsSoftCap_,
|
||||
typename FmhaMask_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
bool kHasDropout_,
|
||||
ck_tile::BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kUseTrLoad_,
|
||||
bool kSkipMinSeqlenQ_ = false,
|
||||
ck_tile::index_t kPageBlockSize_ = 1,
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
|
||||
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
|
||||
DataType_,
|
||||
kIsGroupMode_,
|
||||
kM0_,
|
||||
kN0_,
|
||||
kK0_,
|
||||
kN1_,
|
||||
kK1_,
|
||||
kK0BlockLength_,
|
||||
kIsVLayoutRowMajor_,
|
||||
FmhaPipelineEnum_,
|
||||
kHasLogitsSoftCap_,
|
||||
FmhaMask_,
|
||||
BiasEnum_,
|
||||
kStoreLse_,
|
||||
kHasDropout_,
|
||||
QScaleEnum_,
|
||||
kPadS_,
|
||||
kPadSK_,
|
||||
kPadD_,
|
||||
kPadDv_,
|
||||
kUseTrLoad_,
|
||||
kSkipMinSeqlenQ_,
|
||||
false>
|
||||
{
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
|
||||
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
@@ -1527,7 +1618,15 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
|
||||
fmha_fwd_appendkv_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
using fmha_batch_prefill_traits = fmha_fwd_traits;
|
||||
struct fmha_batch_prefill_traits : public fmha_fwd_traits
|
||||
{
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
|
||||
int page_size = 1;
|
||||
};
|
||||
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits,
|
||||
fmha_batch_prefill_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
@@ -149,6 +149,28 @@ int override_num_splits_if_necessary(
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
template <typename SMPLComputeDataType>
|
||||
void copy_attention_scores_with_sink(const ck_tile::HostTensor<SMPLComputeDataType>& s_host_ref,
|
||||
const ck_tile::HostTensor<SMPLComputeDataType>& sink_host,
|
||||
ck_tile::HostTensor<SMPLComputeDataType>& s_with_sinks_ref,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t real_seqlen_q,
|
||||
ck_tile::index_t real_seqlen_k)
|
||||
{
|
||||
for(auto i_h = 0; i_h < nhead; i_h++)
|
||||
{
|
||||
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
||||
{
|
||||
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
||||
{
|
||||
s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c);
|
||||
}
|
||||
// Append sink token at the end of each row
|
||||
s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataTypeConfig>
|
||||
fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
@@ -184,6 +206,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::string init_method,
|
||||
uint32_t seed,
|
||||
int do_validation,
|
||||
int init_sink_value,
|
||||
const ck_tile::stream_config& stream_config,
|
||||
std::optional<std::string> json = std::nullopt)
|
||||
{
|
||||
@@ -527,6 +550,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
@@ -609,6 +633,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
|
||||
bias_host);
|
||||
}
|
||||
|
||||
else if(init_method == "ni")
|
||||
{
|
||||
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
|
||||
@@ -695,10 +720,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
|
||||
|
||||
if(init_sink_value != 0)
|
||||
{
|
||||
// sink is initialized to a fixed integer value for easy debugging and use 30 to 60 range
|
||||
// for close to rowmax values.
|
||||
ck_tile::FillUniformDistributionIntegerValue<SMPLComputeDataType>{30.f, 60.f, next_seed()}(
|
||||
sink_host);
|
||||
}
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
@@ -743,6 +775,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
sink_buf.ToDevice(sink_host.data());
|
||||
knew_buf.ToDevice(knew_host.data());
|
||||
vnew_buf.ToDevice(vnew_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
@@ -971,7 +1004,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
|
||||
if(init_sink_value != 0)
|
||||
args.sink_ptr = sink_buf.GetDeviceBuffer();
|
||||
else
|
||||
args.sink_ptr = nullptr;
|
||||
args.batch = batch;
|
||||
args.seqlen_q = shape_seqlen_q; // unused in group mode
|
||||
args.hdim_q = hdim_q;
|
||||
@@ -1351,8 +1387,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_qscale)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o_host});
|
||||
return ck_tile::make_composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o_host});
|
||||
else if constexpr(supports_qscale)
|
||||
return ck_tile::scales{scale_o_host};
|
||||
else
|
||||
@@ -1675,19 +1711,57 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
const ck_tile::HostTensor<SaccDataType> masked_s_host_ref = s_host_ref;
|
||||
if(lse)
|
||||
if(init_sink_value != 0)
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
||||
// Create extended tensor with sink token
|
||||
ck_tile::HostTensor<SMPLComputeDataType> s_with_sinks_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k + 1});
|
||||
|
||||
// Copy original attention scores and append sink values
|
||||
copy_attention_scores_with_sink(
|
||||
s_host_ref, sink_host, s_with_sinks_ref, nhead, real_seqlen_q, real_seqlen_k);
|
||||
|
||||
// Compute softmax on extended tensor
|
||||
ck_tile::HostTensor<PDataType> p_extended(
|
||||
{nhead, real_seqlen_q, real_seqlen_k + 1});
|
||||
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_with_sinks_ref, p_extended, p_compute_element_func, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_with_sinks_ref, p_extended, p_compute_element_func);
|
||||
}
|
||||
|
||||
// Extract only the original columns (exclude sink token column)
|
||||
p_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = p_extended(idx[0], idx[1], idx[2]); });
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
// No sink tokens - compute softmax directly
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func);
|
||||
}
|
||||
}
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
|
||||
@@ -84,3 +84,10 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
@@ -69,107 +69,88 @@ struct BasicInvoker
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr =
|
||||
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0],
|
||||
kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -72,160 +72,144 @@ struct SplitKTwoStageInvoker
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WorkspaceType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WorkspaceType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType));
|
||||
ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args);
|
||||
auto c_ptr = ws_args.c_ptr;
|
||||
ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args);
|
||||
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType));
|
||||
ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args);
|
||||
auto c_ptr = ws_args.c_ptr;
|
||||
ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args);
|
||||
const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s)
|
||||
: GemmKernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
|
||||
const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s)
|
||||
: GemmKernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
if(!GemmKernel::IsSupportedArgument(gemm_kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(!GemmKernel::IsSupportedArgument(gemm_kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using BlockTile = ck_tile::sequence<2048>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using BlockTile = ck_tile::sequence<2048>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceType,
|
||||
WorkspaceType,
|
||||
CDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
using ElementwiseKernel =
|
||||
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceType,
|
||||
WorkspaceType,
|
||||
CDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
using ElementwiseKernel =
|
||||
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {args.M, args.N};
|
||||
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {args.M, args.N};
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
|
||||
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize =
|
||||
(total_elements + elements_per_block - 1) / elements_per_block;
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceType*>(ws_args.c_ptr));
|
||||
auto input_size = ck_tile::make_tuple(args.M, args.N);
|
||||
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceType*>(ws_args.c_ptr));
|
||||
auto input_size = ck_tile::make_tuple(args.M, args.N);
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseKernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseKernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr =
|
||||
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
gemm_kargs.as_ptr[0],
|
||||
gemm_kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
GemmKernel{}, grids, blocks, 0, gemm_kargs),
|
||||
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(args.N, 1), // Input Stride
|
||||
ck_tile::make_tuple(args.N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<CDataType*>(c_ptr)));
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
gemm_kargs.as_ptr[0],
|
||||
gemm_kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
GemmKernel{}, grids, blocks, 0, gemm_kargs),
|
||||
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(args.N, 1), // Input Stride
|
||||
ck_tile::make_tuple(args.N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<CDataType*>(c_ptr)));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -160,110 +160,101 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
|
||||
args.stride_E);
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const auto Run = [&]() {
|
||||
// use SET operation since each K-split writes to separate memory
|
||||
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue =
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(base_args);
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(base_args);
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
return Run();
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -460,12 +460,6 @@ inline auto create_args()
|
||||
return arg_parser;
|
||||
}
|
||||
|
||||
// Type aliases for memory operation integral constants
|
||||
using MemoryOpSet =
|
||||
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
|
||||
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>;
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
|
||||
@@ -57,114 +57,95 @@ struct WeightPreshuffleInvoker
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
GemmConfig::TiledMMAPermuteN>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
GemmConfig::TiledMMAPermuteN>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
float ave_time = 0.f;
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(kargs.as_ptr[0],
|
||||
kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time =
|
||||
ck_tile::launch_kernel_time_mask(s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl;
|
||||
}
|
||||
float ave_time = 0.f;
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -60,112 +60,94 @@ struct UniversalInvoker
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s)
|
||||
: Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s)
|
||||
: Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr =
|
||||
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0],
|
||||
kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -15,6 +15,22 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo
|
||||
|
||||
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# Multi Reduce Threadwise Example
|
||||
set(EXAMPLE_MULTI_REDUCE "tile_example_multi_reduce_threadwise")
|
||||
add_executable(${EXAMPLE_MULTI_REDUCE} EXCLUDE_FROM_ALL multiple_reduce_threadwise.cpp)
|
||||
target_include_directories(${EXAMPLE_MULTI_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
target_compile_options(${EXAMPLE_MULTI_REDUCE} PRIVATE ${EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# Multi Reduce Blockwise Example
|
||||
set(EXAMPLE_MULTI_REDUCE_BLOCKWISE "tile_example_multi_reduce_multiblock")
|
||||
add_executable(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} EXCLUDE_FROM_ALL multiple_reduce_multiblock.cpp)
|
||||
target_include_directories(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
target_compile_options(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
|
||||
271
example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp
Normal file
271
example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp
Normal file
@@ -0,0 +1,271 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("n", "32", "n dimension")
|
||||
.insert("h", "19", "h dimension")
|
||||
.insert("w", "7", "w dimension")
|
||||
.insert("c", "512", "c dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "multi_reduce_multiblock.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = float;
|
||||
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t H = arg_parser.get_int("h");
|
||||
ck_tile::index_t W = arg_parser.get_int("w");
|
||||
ck_tile::index_t C = arg_parser.get_int("c");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
// Validate input dimensions
|
||||
const ck_tile::index_t kept_dim_len_prod = N * C;
|
||||
const ck_tile::index_t reduce_total_length = H * W;
|
||||
|
||||
if(kept_dim_len_prod == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C
|
||||
<< ", product=" << kept_dim_len_prod << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty output tensor." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(reduce_total_length == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W
|
||||
<< ", product=" << reduce_total_length << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty reduction with no data to process." << std::endl;
|
||||
std::cerr << "The kernel will exit early without performing any computation." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> problem_shape = {N, H, W, C};
|
||||
std::vector<ck_tile::index_t> strides(4);
|
||||
strides[0] = H * W * C;
|
||||
strides[1] = W * C;
|
||||
strides[2] = C;
|
||||
strides[3] = 1;
|
||||
|
||||
// Define reduction specification:
|
||||
constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep
|
||||
constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);
|
||||
ck_tile::HostTensor<YDataType> y_host_add_ref({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_ref({N, C}, {C, 1});
|
||||
auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref);
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_add_dev({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_dev({N, C}, {C, 1});
|
||||
auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev);
|
||||
|
||||
const auto number_operations = y_host_dev_tuple.size();
|
||||
|
||||
std::vector<YDataType> h(number_operations * N * C);
|
||||
|
||||
auto y_buf_size = number_operations *
|
||||
y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes();
|
||||
ck_tile::DeviceMem y_buf(y_buf_size);
|
||||
|
||||
const auto output_tensor_offset = N * C;
|
||||
|
||||
// Operations: one doing a sum reduction, the other computing the mean square
|
||||
// In the case of mean square:
|
||||
// 1. The element wise operation squares each element before reduction
|
||||
// 2. The reduction operation sum the squared element
|
||||
// 3. The accumulator element wise operation divides the result by the total number of reduced
|
||||
// elements (intra block operation)
|
||||
// 4. The partial result is updated across blocks using inter block reduction, a sum.
|
||||
auto reduce_ops =
|
||||
ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions
|
||||
auto elementwise_ops = ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnarySquare{}); // Elementwise
|
||||
// ops
|
||||
auto accumulator_elementwise_ops = ck_tile::make_tuple(
|
||||
ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnaryDivide{
|
||||
reduce_total_length}); // Accumulator Elementwise ops on reduction, intra block
|
||||
auto inter_block_reduce_ops = ck_tile::make_tuple(
|
||||
ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // Inter block reduction
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using ThreadTile = ck_tile::sequence<8, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
using Problem = ck_tile::Reduce2dProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
Shape,
|
||||
decltype(reduce_ops),
|
||||
decltype(kept_dim),
|
||||
decltype(reduce_dims),
|
||||
4>;
|
||||
|
||||
using Kernel = ck_tile::MultiReduceMultiblock<Problem>;
|
||||
|
||||
// Determine block group size for multi-block reduction
|
||||
// block_group_size records how many blocks participate to a reduction (input data dependent)
|
||||
// , for efficiency reasons this size if limited to a maximum of 128. If this is not sufficient
|
||||
// to process the whole reduction, each thread will to process multiple thread tile
|
||||
// a num_block_tile_iterations times
|
||||
auto [num_block_tile_iterations, block_group_size] =
|
||||
typename Kernel::TilePartitioner{reduce_total_length}.GetBlockGroupParams();
|
||||
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
ck_tile::index_t kGridSize =
|
||||
((kept_dim_len_prod + Shape::Block_M - 1) / Shape::Block_M) * block_group_size;
|
||||
|
||||
std::cout << "Block group size: " << block_group_size
|
||||
<< ", Num block tile iterations: " << num_block_tile_iterations
|
||||
<< ", Reduce total length: " << reduce_total_length << std::endl;
|
||||
std::cout << "grid size " << kGridSize << ", block size " << kBlockSize << std::endl;
|
||||
|
||||
// Create input tensor shape and strides
|
||||
auto input_shape =
|
||||
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
|
||||
auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(
|
||||
C, input_strides)) // output tensor's continuous dimension and input strides
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
|
||||
// Init the output data with identity values respective to each reduce op
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
constexpr auto op = reduce_ops.at(i);
|
||||
const auto identity_val = op.template GetIdentityValue<YDataType>();
|
||||
const auto output_number_elements = N * C;
|
||||
std::fill(h.begin() + i * output_number_elements,
|
||||
h.begin() + (i + 1) * output_number_elements,
|
||||
identity_val);
|
||||
});
|
||||
|
||||
auto clear_output_buffer = [&]() { y_buf.ToDevice(h.data()); };
|
||||
|
||||
float ave_time = launch_kernel_time_mask(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
clear_output_buffer,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
input_strides,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
output_tensor_offset,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops,
|
||||
inter_block_reduce_ops)
|
||||
|
||||
);
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
ck_tile::reference_multiple_reduce_multiblock<XDataType, ComputeDataType, YDataType>(
|
||||
x_host,
|
||||
y_host_ref_tuple,
|
||||
reduce_ops,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops,
|
||||
inter_block_reduce_ops,
|
||||
block_group_size);
|
||||
std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl;
|
||||
|
||||
// Transfer data from device and check error for each operation
|
||||
y_buf.FromDevice(h.data());
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
std::memcpy(y_host_dev_tuple.get(ck_tile::number<i>{}).data(),
|
||||
h.data() + i * output_tensor_offset,
|
||||
output_tensor_offset * sizeof(YDataType));
|
||||
std::cout << "Checking operation " << i << ": " << std::endl;
|
||||
|
||||
bool pass_op = ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number<i>{}),
|
||||
y_host_ref_tuple.get(ck_tile::number<i>{}));
|
||||
|
||||
if(pass_op)
|
||||
{
|
||||
std::cout << "✅ valid results for this operation" << std::endl;
|
||||
}
|
||||
pass &= pass_op;
|
||||
});
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
224
example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp
Normal file
224
example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp
Normal file
@@ -0,0 +1,224 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("n", "32", "n dimension")
|
||||
.insert("h", "7", "h dimension")
|
||||
.insert("w", "7", "w dimension")
|
||||
.insert("c", "512", "c dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "multi_reduce.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = DataType;
|
||||
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t H = arg_parser.get_int("h");
|
||||
ck_tile::index_t W = arg_parser.get_int("w");
|
||||
ck_tile::index_t C = arg_parser.get_int("c");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
// Validate input dimensions
|
||||
const ck_tile::index_t kept_dim_len_prod = N * C;
|
||||
const ck_tile::index_t reduce_total_length = H * W;
|
||||
|
||||
if(kept_dim_len_prod == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C
|
||||
<< ", product=" << kept_dim_len_prod << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty output tensor." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(reduce_total_length == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W
|
||||
<< ", product=" << reduce_total_length << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty reduction with no data to process." << std::endl;
|
||||
std::cerr << "The kernel will exit early without performing any computation." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> problem_shape = {N, H, W, C};
|
||||
std::vector<ck_tile::index_t> strides(4);
|
||||
strides[0] = H * W * C;
|
||||
strides[1] = W * C;
|
||||
strides[2] = C;
|
||||
strides[3] = 1;
|
||||
|
||||
// Define reduction specification:
|
||||
constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep
|
||||
constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);
|
||||
ck_tile::HostTensor<YDataType> y_host_add_ref({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_ref({N, C}, {C, 1});
|
||||
auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref);
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_add_dev({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_dev({N, C}, {C, 1});
|
||||
auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev);
|
||||
|
||||
const auto number_operations = y_host_dev_tuple.size();
|
||||
|
||||
// Two operations: one do a sum reduction, the other computing the mean square
|
||||
auto reduce_ops =
|
||||
ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions ops
|
||||
auto elementwise_ops =
|
||||
ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnarySquare{}); // Elementwise ops
|
||||
auto accumulator_elementwise_ops =
|
||||
ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnaryDivide{
|
||||
reduce_total_length}); // Accumulator Elementiwise ops on reduction,
|
||||
|
||||
auto y_buf_size = number_operations *
|
||||
y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes();
|
||||
ck_tile::DeviceMem y_buf(y_buf_size);
|
||||
|
||||
const auto output_tensor_offset = N * C;
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using ThreadTile = ck_tile::sequence<8, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) /
|
||||
BlockTile::at(ck_tile::number<0>{});
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
using Problem = ck_tile::Reduce2dProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
Shape,
|
||||
decltype(reduce_ops),
|
||||
decltype(kept_dim),
|
||||
decltype(reduce_dims),
|
||||
4>;
|
||||
|
||||
using Kernel = ck_tile::MultiReduceThreadWise<Problem>;
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
|
||||
// Create input tensor shape and strides
|
||||
auto input_shape =
|
||||
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
|
||||
auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(
|
||||
C, input_strides)) // output tensor's continuous dimension and input strides
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
|
||||
float ave_time = launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
input_strides,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
output_tensor_offset,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops));
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
std::vector<YDataType> h(number_operations * N * C);
|
||||
|
||||
// reference
|
||||
ck_tile::reference_multiple_reduce<XDataType, ComputeDataType, YDataType>(
|
||||
x_host,
|
||||
y_host_ref_tuple,
|
||||
reduce_ops,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops);
|
||||
std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl;
|
||||
|
||||
// Transfer data from device and check error for each operation
|
||||
y_buf.FromDevice(h.data());
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
std::memcpy(y_host_dev_tuple.get(ck_tile::number<i>{}).data(),
|
||||
h.data() + i * output_tensor_offset,
|
||||
output_tensor_offset * sizeof(YDataType));
|
||||
pass &= ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number<i>{}),
|
||||
y_host_ref_tuple.get(ck_tile::number<i>{}));
|
||||
});
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
@@ -334,13 +334,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
printf("moe_buf:%lu(%d,%d), ",
|
||||
printf("moe_buf:%" PRIu64 "(%d,%d), ",
|
||||
static_cast<uint64_t>(moe_buf_bytes),
|
||||
moe_buf_interm_dim,
|
||||
moe_buf_elem_bytes);
|
||||
#else
|
||||
|
||||
printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
|
||||
printf("moe_buf:%" PRIu64 ", ", static_cast<uint64_t>(moe_buf_bytes));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -78,63 +78,48 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
else
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
@@ -14,7 +14,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
quant_grouped_gemm_bf8_rowcol.cpp
|
||||
quant_grouped_gemm_bf8_tensor.cpp
|
||||
)
|
||||
|
||||
add_executable(tile_example_abquant_grouped_gemm abquant_grouped_gemm.cpp)
|
||||
add_executable(tile_example_grouped_gemm_preshuffle grouped_gemm_preshuffle.cpp)
|
||||
add_executable(tile_example_grouped_gemm_multi_d grouped_gemm_multi_d.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
@@ -25,4 +25,5 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_abquant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
278
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp
Normal file
278
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp
Normal file
@@ -0,0 +1,278 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "abquant_grouped_gemm.hpp"
|
||||
|
||||
// Non-persistent grouped gemm for ABQuant
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
// Persistent grouped gemm tileloop for ABQuant
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
GemmConfig::TransposeC>;
|
||||
|
||||
using GemmPipeline = GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_abquant_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
int result1 = run_abquant_grouped_gemm_example(argc, argv);
|
||||
return result1;
|
||||
}
|
||||
171
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp
Normal file
171
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp
Normal file
@@ -0,0 +1,171 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <bool Persistent_>
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool Persistent = Persistent_;
|
||||
};
|
||||
|
||||
template <typename PrecType, bool Persistent>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <ck_tile::QuantType QuantMode>
|
||||
struct GemmQuantConfig;
|
||||
|
||||
// ABQuant specialization for GemmQuantConfig
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::ABQuantGrouped>
|
||||
{
|
||||
template <typename PrecType, bool Persistent>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
|
||||
.insert("Ns", "", "N dimensions - empty by default.")
|
||||
.insert("Ks", "", "K dimensions - empty by default.")
|
||||
.insert(
|
||||
"stride_As",
|
||||
"",
|
||||
"Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs
|
||||
// can be set to zero if
|
||||
// Ms/Ns/Ks is not empty
|
||||
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
|
||||
.insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("init", "0", "0. Random, 2. One(s) (Constant)")
|
||||
.insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.")
|
||||
.insert("bquant_group_size", "1x1x128", "BQuant group size. 1x1x128 (default) or 1x128x128")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "abquant_grouped_gemm.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
|
||||
}
|
||||
|
||||
// Forward declaration of the non-persistent version
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
|
||||
float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr);
|
||||
|
||||
// Forward declaration of the tileloop version for persistent kernels
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr);
|
||||
@@ -62,71 +62,55 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
else
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
@@ -139,8 +123,7 @@ template <typename GemmConfig,
|
||||
typename CDataType>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
void* kargs_ptr)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
@@ -161,74 +144,55 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
float ave_time{0};
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
};
|
||||
|
||||
if(!splitk)
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ave_time =
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
@@ -328,5 +328,4 @@ template <typename GemmConfig,
|
||||
typename CDataType>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk = false);
|
||||
void* kargs_ptr);
|
||||
|
||||
@@ -61,72 +61,56 @@ float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_d
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { "
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
else
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { "
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
@@ -142,8 +126,7 @@ template <typename GemmConfig,
|
||||
typename CDEElementWise>
|
||||
float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
void* kargs_ptr)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
@@ -163,76 +146,55 @@ float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s,
|
||||
BLayout,
|
||||
ELayout>;
|
||||
|
||||
float ave_time{0};
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
if(!splitk)
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_multi_d_example.inc"
|
||||
|
||||
@@ -65,70 +65,54 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
else
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
@@ -141,8 +125,7 @@ template <typename GemmConfig,
|
||||
typename CDataType>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
void* kargs_ptr)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
@@ -167,75 +150,53 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
float ave_time{0};
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(splitk)
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
@@ -72,10 +72,9 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
@@ -137,8 +136,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
QuantGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
@@ -224,90 +222,79 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
|
||||
float ave_time{0};
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
GemmConfig::TransposeC>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler>>;
|
||||
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
GemmConfig::TransposeC>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler>>;
|
||||
using GemmPipeline = GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
};
|
||||
|
||||
return ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
}
|
||||
|
||||
@@ -0,0 +1,604 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_abquant_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int group_count,
|
||||
const std::vector<grouped_gemm_kargs>& args)
|
||||
{
|
||||
// Workspace memory allocated to hold the gemm descriptions.
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(args));
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if constexpr(!GemmConfig::Persistent)
|
||||
{
|
||||
ave_time = grouped_gemm_abquant<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host. Instead, we can just pass the pointer
|
||||
// to the kernel and let the workgroups figure out which tiles to work on.
|
||||
// This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
if(args[0].k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("Split-K not supported yet for persistent kernel");
|
||||
}
|
||||
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.aq_ptr,
|
||||
arg.bq_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.QK_A,
|
||||
arg.QK_B,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_E,
|
||||
arg.stride_AQ,
|
||||
arg.stride_BQ,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(stream, group_count, kargs_ptr);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout>
|
||||
int run_abquant_grouped_gemm_example_with_layouts(
|
||||
int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const AQLayout aq_layout = AQLayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const BQLayout bq_layout = BQLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
const int init_method = arg_parser.get_int("init");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
|
||||
if(kbatch > 1 && validate && warmup + repeat > 1)
|
||||
{
|
||||
std::cout << "WARNING: Data validation enabled with SplitK and more than"
|
||||
<< "1 warmup/repeat. Disabling validation." << std::endl;
|
||||
validate = false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
std::vector<ck_tile::index_t> AQs; // dimension of AQ tensor is calculated from A tensor
|
||||
std::vector<ck_tile::index_t> BQs; // dimension of BQ tensor is calculated from B tensor
|
||||
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
|
||||
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
|
||||
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
|
||||
std::vector<ck_tile::index_t> stride_AQs = arg_parser.get_int_vec("stride_AQs");
|
||||
std::vector<ck_tile::index_t> stride_BQs = arg_parser.get_int_vec("stride_BQs");
|
||||
|
||||
ck_tile::index_t AQK, BQK;
|
||||
|
||||
if(!valid_input_data(
|
||||
group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs))
|
||||
{
|
||||
std::cout << "Please check the input data. Default values will be used." << std::endl;
|
||||
|
||||
// Clear existing (invalid) data before adding defaults
|
||||
Ms.clear();
|
||||
Ns.clear();
|
||||
Ks.clear();
|
||||
stride_As.clear();
|
||||
stride_Bs.clear();
|
||||
stride_Cs.clear();
|
||||
stride_AQs.clear();
|
||||
stride_BQs.clear();
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
|
||||
// Let get_default_stride calculate based on layout
|
||||
stride_As.push_back(0);
|
||||
stride_Bs.push_back(0);
|
||||
stride_Cs.push_back(0);
|
||||
stride_AQs.push_back(0);
|
||||
stride_BQs.push_back(0);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<AQDataType>> aq_tensors;
|
||||
std::vector<ck_tile::HostTensor<BQDataType>> bq_tensors;
|
||||
|
||||
a_m_k_tensors.reserve(group_count);
|
||||
b_k_n_tensors.reserve(group_count);
|
||||
c_m_n_tensors.reserve(group_count);
|
||||
aq_tensors.reserve(group_count);
|
||||
bq_tensors.reserve(group_count);
|
||||
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> aq_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> bq_dev_buf;
|
||||
|
||||
a_m_k_dev_buf.reserve(group_count);
|
||||
b_k_n_dev_buf.reserve(group_count);
|
||||
c_m_n_dev_buf.reserve(group_count);
|
||||
aq_dev_buf.reserve(group_count);
|
||||
bq_dev_buf.reserve(group_count);
|
||||
|
||||
std::vector<grouped_gemm_kargs> gemm_descs;
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
|
||||
const ck_tile::index_t M = Ms[i];
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
// For ABQuantGrouped, both A and B need quantization
|
||||
static_assert(QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
"This file only supports ABQuantGrouped mode");
|
||||
|
||||
AQK = K / AQuantGroupSize::kK; // Group quantization: AQK = K / AQuantGroupSize
|
||||
BQK = K / BQuantGroupSize::kK; // Group quantization: BQK = K / BQuantGroupSize
|
||||
if(K % AQuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode");
|
||||
}
|
||||
if(K % BQuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode");
|
||||
}
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
|
||||
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
|
||||
stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
|
||||
stride_BQs[i] = ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout));
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
|
||||
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
|
||||
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
|
||||
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
|
||||
<< " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl;
|
||||
|
||||
if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{1.f, 1.f}(bq_tensors[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
b_k_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
c_m_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
aq_dev_buf.push_back(
|
||||
std::make_unique<ck_tile::DeviceMem>(aq_tensors[i].get_element_space_size_in_bytes()));
|
||||
bq_dev_buf.push_back(
|
||||
std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
|
||||
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
c_m_n_tensors[i].SetZero();
|
||||
|
||||
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back({p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_aq,
|
||||
p_bq,
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
AQK,
|
||||
BQK,
|
||||
stride_As[i],
|
||||
stride_Bs[i],
|
||||
stride_Cs[i],
|
||||
stride_AQs[i],
|
||||
stride_BQs[i]});
|
||||
}
|
||||
|
||||
float ave_time = invoke_abquant_gemm<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(warmup, repeat, group_count, gemm_descs);
|
||||
|
||||
std::string op_name = "ABQuant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(int j = 0; j < group_count; ++j)
|
||||
{
|
||||
flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K;
|
||||
|
||||
num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K +
|
||||
sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N +
|
||||
sizeof(CDataType) * gemm_descs[j].M * gemm_descs[j].N;
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
|
||||
}
|
||||
|
||||
bool pass{true};
|
||||
if(validate)
|
||||
{
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
|
||||
Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Reference implementation for ABQuantGrouped
|
||||
ck_tile::reference_gemm_abquant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize>(
|
||||
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
Ks[i], kbatch, max_accumulated_value);
|
||||
pass &=
|
||||
ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results! in group [" + std::to_string(i) + "]",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "gemm[" << i
|
||||
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_grouped_gemm_json_results<ALayout, BLayout, CLayout>(arg_parser.get_str("jsonfile"),
|
||||
op_name,
|
||||
group_count,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename PrecType, typename GemmConfig, typename BQuantGroupSize>
|
||||
int run_abquant_grouped_gemm_example_prec_type_with_bquant(
|
||||
std::string a_layout, std::string b_layout, std::string c_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Types = GemmTypeConfig<PrecType>;
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using AccDataType = typename Types::AccDataType;
|
||||
using CDataType = typename Types::CDataType;
|
||||
using AQDataType = typename Types::AccDataType;
|
||||
using BQDataType = typename Types::AccDataType;
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C" && c_layout == "R")
|
||||
{
|
||||
return run_abquant_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R" && c_layout == "R")
|
||||
{
|
||||
return run_abquant_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R" && c_layout == "R")
|
||||
{
|
||||
return run_abquant_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
argc, argv, Col{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PrecType, typename GemmConfig>
|
||||
int run_abquant_grouped_gemm_example_prec_type(std::string a_layout,
|
||||
std::string b_layout,
|
||||
std::string c_layout,
|
||||
std::string bquant_group_size,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
if(bquant_group_size == "1x1x128")
|
||||
{
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_abquant_grouped_gemm_example_prec_type_with_bquant<PrecType,
|
||||
GemmConfig,
|
||||
BQuantGroupSize>(
|
||||
a_layout, b_layout, c_layout, argc, argv);
|
||||
}
|
||||
else if(bquant_group_size == "1x128x128")
|
||||
{
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return run_abquant_grouped_gemm_example_prec_type_with_bquant<PrecType,
|
||||
GemmConfig,
|
||||
BQuantGroupSize>(
|
||||
a_layout, b_layout, c_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported BQuantGroupSize! Use 1x1x128 or 1x128x128.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PrecType>
|
||||
int run_abquant_gemm_example_persistency(std::string a_layout,
|
||||
std::string b_layout,
|
||||
std::string c_layout,
|
||||
bool persistent,
|
||||
std::string bquant_group_size,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
if(persistent)
|
||||
{
|
||||
using GemmConfig = typename GemmQuantConfig<
|
||||
ck_tile::QuantType::ABQuantGrouped>::template GemmConfig<PrecType, true>;
|
||||
return run_abquant_grouped_gemm_example_prec_type<PrecType, GemmConfig>(
|
||||
a_layout, b_layout, c_layout, bquant_group_size, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
using GemmConfig = typename GemmQuantConfig<
|
||||
ck_tile::QuantType::ABQuantGrouped>::template GemmConfig<PrecType, false>;
|
||||
return run_abquant_grouped_gemm_example_prec_type<PrecType, GemmConfig>(
|
||||
a_layout, b_layout, c_layout, bquant_group_size, argc, argv);
|
||||
}
|
||||
}
|
||||
|
||||
int run_abquant_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string c_layout = arg_parser.get_str("c_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
bool persistent = arg_parser.get_bool("persistent");
|
||||
const std::string bquant_group_size = arg_parser.get_str("bquant_group_size");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
return run_abquant_gemm_example_persistency<ck_tile::fp8_t>(
|
||||
a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_abquant_gemm_example_persistency<ck_tile::bf8_t>(
|
||||
a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type configuration.");
|
||||
}
|
||||
}
|
||||
@@ -79,8 +79,7 @@ float invoke_gemm(int n_warmup,
|
||||
// earlier stage.
|
||||
|
||||
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = args[0].k_batch > 1;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr},
|
||||
@@ -109,7 +108,7 @@ float invoke_gemm(int n_warmup,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType>(stream, group_count, kargs_ptr, splitk);
|
||||
CDataType>(stream, group_count, kargs_ptr);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
|
||||
@@ -95,8 +95,7 @@ float invoke_gemm(int n_warmup,
|
||||
else
|
||||
{
|
||||
std::vector<ck_tile::GemmTransKernelArg<NumDTensor>> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = args[0].k_batch > 1;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, NumDTensor>{{arg.a_ptr},
|
||||
@@ -119,18 +118,17 @@ float invoke_gemm(int n_warmup,
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<NumDTensor>),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time =
|
||||
grouped_gemm_multi_d_tileloop<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise>(stream, group_count, kargs_ptr, splitk);
|
||||
ave_time = grouped_gemm_multi_d_tileloop<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise>(stream, group_count, kargs_ptr);
|
||||
}
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -170,13 +170,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -207,7 +204,6 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
@@ -282,23 +278,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
|
||||
@@ -113,13 +113,10 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config&
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -150,7 +147,6 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config&
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups>>;
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
@@ -216,23 +212,7 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config&
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
|
||||
@@ -113,13 +113,10 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
using CodegenPipelineProblem =
|
||||
std::conditional_t<MXFP4_Pipeline,
|
||||
@@ -159,7 +156,6 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
@@ -265,23 +261,7 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
|
||||
@@ -89,13 +89,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
@@ -128,7 +125,6 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
@@ -201,23 +197,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
|
||||
@@ -144,15 +144,11 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -184,7 +180,6 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
@@ -261,37 +256,20 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
args.NumTokens * args.TopK * outputN * sizeof(CDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
float ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
|
||||
@@ -61,8 +61,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
|
||||
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation =
|
||||
Splitk ? ck_tile::memory_operation_enum::atomic_add : ck_tile::memory_operation_enum::set;
|
||||
ck_tile::ignore = Splitk;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
@@ -98,7 +97,6 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
MXPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
|
||||
@@ -81,87 +81,45 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
// Epilogue selection: set to true for chainer-based, false for standard
|
||||
// CShuffleEpilogue
|
||||
constexpr bool UseChainerEpilogue = true;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
UseChainerEpilogue,
|
||||
// Chainer-based epilogue
|
||||
ck_tile::EpilogueChainer<ck_tile::CshuffleEpilogueSchedule<
|
||||
ck_tile::CShuffleEpilogueChainProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>,
|
||||
ck_tile::DefaultScheduleTag>>,
|
||||
// Standard CShuffleEpilogue
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>>;
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
|
||||
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y
|
||||
<< ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
else
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
|
||||
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
#include "run_gemm_multi_d_fp16_example.inc"
|
||||
|
||||
@@ -59,94 +59,80 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
ConvConfig::NumWaveGroups>;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
InDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
InDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
ck_tile::hip_check_error(hipMemsetAsync(
|
||||
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
|
||||
};
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
auto preprocess = [&]() {
|
||||
ck_tile::hip_check_error(hipMemsetAsync(
|
||||
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
}
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -59,104 +59,85 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
ConvConfig::NumWaveGroups>;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WeiDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WeiDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
const auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
ck_tile::hip_check_error(hipMemsetAsync(
|
||||
kargs.wei_ptr, 0, args.template GetWeightByte<WeiDataType>(), s.stream_id_));
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(kargs.wei_ptr,
|
||||
0,
|
||||
args.template GetWeightByte<WeiDataType>(),
|
||||
s.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
const auto ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
const auto split_k = kargs.k_batch;
|
||||
|
||||
return InvokerResult{ave_time, split_k};
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
}
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return InvokerResult{ave_time, args.k_batch};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -65,163 +65,143 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType, // A: Out
|
||||
InDataType, // B: In
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WorkspaceDataType, // C: Workspace normally Out
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType, // A: Out
|
||||
InDataType, // B: In
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WorkspaceDataType, // C: Workspace normally Out
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
const ck_tile::index_t spatial_lengths_accum =
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum *
|
||||
sizeof(WorkspaceDataType));
|
||||
ck_tile::GroupedConvBwdWeightHostArgs ws_args = ck_tile::GroupedConvBwdWeightHostArgs(args);
|
||||
auto c_ptr = ws_args.wei_ptr;
|
||||
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
|
||||
const ck_tile::index_t spatial_lengths_accum =
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum *
|
||||
sizeof(WorkspaceDataType));
|
||||
ck_tile::GroupedConvBwdWeightHostArgs ws_args =
|
||||
ck_tile::GroupedConvBwdWeightHostArgs(args);
|
||||
auto c_ptr = ws_args.wei_ptr;
|
||||
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
const auto kargs = Kernel::MakeKernelArgs(ws_args);
|
||||
const auto kargs = Kernel::MakeKernelArgs(ws_args);
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using BlockTile = ck_tile::sequence<2048>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using BlockTile = ck_tile::sequence<2048>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceDataType,
|
||||
WorkspaceDataType,
|
||||
WeiDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
using ElementwiseKernel =
|
||||
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceDataType,
|
||||
WorkspaceDataType,
|
||||
WeiDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
using ElementwiseKernel =
|
||||
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {
|
||||
static_cast<ck_tile::index_t>(args.G_ * args.K_),
|
||||
static_cast<ck_tile::index_t>(args.C_ * spatial_lengths_accum)};
|
||||
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {
|
||||
static_cast<ck_tile::index_t>(args.G_ * args.K_),
|
||||
static_cast<ck_tile::index_t>(args.C_ * spatial_lengths_accum)};
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
|
||||
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize =
|
||||
(total_elements + elements_per_block - 1) / elements_per_block;
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
|
||||
auto input_size = ck_tile::make_tuple(shape[0], shape[1]);
|
||||
|
||||
auto input_tensors =
|
||||
ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
|
||||
auto input_size = ck_tile::make_tuple(shape[0], shape[1]);
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseKernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseKernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(kargs.k_batch > 1)
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(ws_args.wei_ptr,
|
||||
0,
|
||||
shape[0] * shape[1] * sizeof(WorkspaceDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
|
||||
const auto ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
|
||||
ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(shape[1], 1), // Input Stride
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
|
||||
const auto split_k = kargs.k_batch;
|
||||
|
||||
return InvokerResult{ave_time, split_k};
|
||||
auto preprocess = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(ws_args.wei_ptr,
|
||||
0,
|
||||
shape[0] * shape[1] * sizeof(WorkspaceDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
}
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
|
||||
ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(shape[1], 1), // Input Stride
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
return InvokerResult{ave_time, kargs.k_batch};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -70,91 +70,74 @@ struct GroupedConvolutionForwardInvoker
|
||||
// =====================================================================
|
||||
// Regular Convolution: Simple, no split-image
|
||||
// =====================================================================
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
// =====================================================================
|
||||
// Split-K dispatch
|
||||
// =====================================================================
|
||||
if(args.k_batch == 1)
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
else
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -213,8 +213,7 @@ struct GroupedConvolutionForwardInvoker
|
||||
// =====================================================================
|
||||
// Kernel launch lambda: Uses EnableSplitImage based on layout support
|
||||
// =====================================================================
|
||||
const auto Run = [&](const auto memory_operation_, const auto enable_split_image_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto enable_split_image_) {
|
||||
constexpr bool EnableSplitImage = enable_split_image_.value;
|
||||
|
||||
using GroupedConvTraitsType = std::conditional_t<EnableSplitImage,
|
||||
@@ -255,7 +254,6 @@ struct GroupedConvolutionForwardInvoker
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
@@ -332,17 +330,11 @@ struct GroupedConvolutionForwardInvoker
|
||||
// =====================================================================
|
||||
if(use_split_image)
|
||||
{
|
||||
if(args.k_batch == 1)
|
||||
return Run(MemoryOpSet{}, ck_tile::bool_constant<true>{});
|
||||
else
|
||||
return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant<true>{});
|
||||
return Run(ck_tile::bool_constant<true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
if(args.k_batch == 1)
|
||||
return Run(MemoryOpSet{}, ck_tile::bool_constant<false>{});
|
||||
else
|
||||
return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant<false>{});
|
||||
return Run(ck_tile::bool_constant<false>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -13,11 +13,6 @@
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "conv_configs.hpp"
|
||||
|
||||
using MemoryOpSet =
|
||||
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
|
||||
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>;
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
|
||||
const ck_tile::index_t kbatch,
|
||||
|
||||
@@ -85,60 +85,44 @@ auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_conf
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
|
||||
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y
|
||||
<< ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
else
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
|
||||
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
#include "run_gemm_multi_abd_fp16_example.inc"
|
||||
|
||||
@@ -20,9 +20,18 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
gemm_bquant_quantgrouped_bf16mxfp4.cpp
|
||||
gemm_bquant_quantgrouped_bf8.cpp
|
||||
gemm_bquant_quantgrouped_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_bf8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant_bf8.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp
|
||||
gemm_quant_rowcol.cpp
|
||||
gemm_quant_tensor.cpp
|
||||
)
|
||||
|
||||
@@ -69,4 +69,64 @@ void abquant_quantgrouped_instance_factory(
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -49,4 +49,10 @@ void bquant_quantgrouped_bf8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -51,4 +51,10 @@ void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -49,4 +49,10 @@ void bquant_quantgrouped_fp8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -51,4 +51,10 @@ void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,222 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -111,11 +111,29 @@ void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf16fp4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
void bquant_quantgrouped_preshuffleb_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
void bquant_quantgrouped_preshuffleb_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
|
||||
void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void quant_rowcol_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
@@ -144,9 +162,18 @@ int main(int argc, char* argv[])
|
||||
bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf16fp4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(lut);
|
||||
quant_rowcol_instance_factory(lut);
|
||||
quant_tensor_instance_factory(lut);
|
||||
|
||||
|
||||
@@ -74,9 +74,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>>;
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>;
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
@@ -145,26 +146,33 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>>;
|
||||
using AQuantPipeline =
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>;
|
||||
|
||||
using BQuantPipeline = std::conditional_t<
|
||||
GemmConfig::PreshuffleB,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
|
||||
|
||||
using ABQuantPipeline =
|
||||
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
|
||||
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>>>;
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
AQuantPipeline,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
ABQuantPipeline,
|
||||
BQuantPipeline>>>;
|
||||
|
||||
constexpr bool TiledPermuteN =
|
||||
(BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
|
||||
@@ -173,77 +181,30 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
printf(
|
||||
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
|
||||
}
|
||||
|
||||
// Epilogue selection: use chainer for RowCol/Tensor quant, standard for others
|
||||
// Toggle to switch between chainer-based and standard CShuffleEpilogue
|
||||
constexpr bool UseChainerEpilogue = true;
|
||||
|
||||
// Define the schedule tag based on quant mode
|
||||
using ScheduleTag =
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::RowColQuant,
|
||||
ck_tile::RowColQuantScheduleTag,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::TensorQuantScheduleTag,
|
||||
ck_tile::DefaultScheduleTag>>;
|
||||
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
UseChainerEpilogue && (QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant),
|
||||
// Chainer-based epilogue for RowCol/Tensor quant modes
|
||||
ck_tile::EpilogueChainer<ck_tile::CshuffleEpilogueSchedule<
|
||||
ck_tile::CShuffleEpilogueChainProblem<
|
||||
typename TypeConfig::ADataType,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledPermuteN>,
|
||||
ScheduleTag>>,
|
||||
// Standard CShuffleEpilogue for other modes
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
typename TypeConfig::ADataType,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
typename TypeConfig::ADataType,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledPermuteN>>>;
|
||||
|
||||
typename TypeConfig::BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledPermuteN>>;
|
||||
using Kernel =
|
||||
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
|
||||
|
||||
@@ -579,7 +540,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
@@ -955,8 +916,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if((QuantMode == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>) &&
|
||||
GemmConfig::PreshuffleB)
|
||||
@@ -985,7 +945,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
|
||||
!GemmConfig::PreshuffleQuant)
|
||||
!GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
|
||||
@@ -48,112 +48,87 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS,
|
||||
GemmConfiguration::PRESHUFFLE>;
|
||||
|
||||
const auto runKernel = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccumulatorDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfiguration::SCHEDULER>;
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccumulatorDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfiguration::SCHEDULER>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfiguration::M_WARP,
|
||||
GemmConfiguration::N_WARP,
|
||||
GemmConfiguration::M_WARP_TILE,
|
||||
GemmConfiguration::N_WARP_TILE,
|
||||
GemmConfiguration::K_WARP_TILE,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation.value,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfiguration::M_WARP,
|
||||
GemmConfiguration::N_WARP,
|
||||
GemmConfiguration::M_WARP_TILE,
|
||||
GemmConfiguration::N_WARP_TILE,
|
||||
GemmConfiguration::K_WARP_TILE,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS>>;
|
||||
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kernel_args = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
auto kernel_args = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
workspace_data.SetZero();
|
||||
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
|
||||
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
{
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
}
|
||||
|
||||
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
auto reset_data_buffers = [&]() {
|
||||
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
}
|
||||
|
||||
auto reset_data_buffers = [&]() {
|
||||
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
}
|
||||
};
|
||||
|
||||
std::function<void()> preprocess = reset_data_buffers;
|
||||
|
||||
float average_time =
|
||||
ck_tile::launch_kernel_time_mask(stream_config,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
|
||||
Kernel{}, grids, blocks, 0, kernel_args));
|
||||
|
||||
ck_tile::index_t num_wgs_per_tile =
|
||||
kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
return std::tuple{average_time, num_wgs_per_tile};
|
||||
};
|
||||
|
||||
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy)
|
||||
{
|
||||
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// Since we are doing stream K, in the case of
|
||||
// atomics, multiple workgroups may write to the
|
||||
// same output tile in the C tensor, so we must
|
||||
// atomic add the results (not set)
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else // We are using ck_tile::StreamKReductionStrategy::Reduction
|
||||
{
|
||||
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// In this case, there is only ever 1 WG writing
|
||||
// final results to each macro tile in the C
|
||||
// tensor, so we can do a set.
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
std::function<void()> preprocess = reset_data_buffers;
|
||||
|
||||
float average_time =
|
||||
ck_tile::launch_kernel_time_mask(stream_config,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
|
||||
Kernel{}, grids, blocks, 0, kernel_args));
|
||||
|
||||
ck_tile::index_t num_wgs_per_tile = kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
return std::tuple{average_time, num_wgs_per_tile};
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
@@ -92,67 +92,59 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataT
|
||||
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
|
||||
const auto Run = [&]() {
|
||||
constexpr auto memory_operation =
|
||||
ck_tile::memory_operation_enum::set; // Always set (no atomic_add)
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel =
|
||||
ck_tile::BatchedContractionKernel<Problem, TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
using Kernel =
|
||||
ck_tile::BatchedContractionKernel<Problem, TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::GetBlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::GetBlockSize();
|
||||
if(!Kernel::IsSupportedArguments(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArguments(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
auto kernel = ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs);
|
||||
|
||||
auto kernel = ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs);
|
||||
|
||||
return ck_tile::launch_kernel(s, kernel);
|
||||
};
|
||||
|
||||
return Run();
|
||||
return ck_tile::launch_kernel(s, kernel);
|
||||
}
|
||||
|
||||
#define HANDLE_CASE(G, M, N, K) \
|
||||
|
||||
Reference in New Issue
Block a user