mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Merge remote-tracking branch 'origin/develop' into samremes/ck_tile_mx_gemm
This commit is contained in:
@@ -47,7 +47,7 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--targets ${FMHA_TARGETS_ARG}
|
||||
--api ${FMHA_FWD_APIS}
|
||||
--optdim 32,64,128,256
|
||||
--optdim 32,64,80,128,256
|
||||
# --filter fmha_fwd...
|
||||
)
|
||||
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
|
||||
@@ -24,11 +24,31 @@ from codegen.cpp_symbol_map import (
|
||||
)
|
||||
from codegen.utils import update_file
|
||||
|
||||
|
||||
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8": 8,
|
||||
"fp8bf16": 8,
|
||||
"fp8fp32": 8,
|
||||
"bf8": 8,
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024]
|
||||
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
|
||||
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
|
||||
KV_MEMORY_LAYOUT_ENUM_MAP = {
|
||||
"vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
|
||||
"linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
|
||||
}
|
||||
KV_LOOKUP_TABLE_ENUM_MAP = {
|
||||
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
|
||||
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
|
||||
}
|
||||
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
|
||||
}
|
||||
@@ -52,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
@@ -62,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_qscale},
|
||||
{F_occupancy}>;
|
||||
{F_occupancy},
|
||||
false,
|
||||
{F_page_size},
|
||||
{F_kv_memory_layout},
|
||||
{F_kv_lookup_table}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
@@ -85,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
false,
|
||||
{F_page_size},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
@@ -98,8 +123,8 @@ using fmha_epilogue_{F_idx} =
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -108,7 +133,7 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
std::cout << ", {F_kname}" << std::flush;
|
||||
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
@@ -177,8 +202,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -223,12 +248,15 @@ class FmhaFwdApiTrait:
|
||||
dpad: str
|
||||
dvpad: str
|
||||
constraint: CppConstraint
|
||||
kv_memory_layout: str
|
||||
kv_lookup_table: str
|
||||
page_size: int = 1 # page block size
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -315,6 +343,8 @@ class FmhaFwdPipeline:
|
||||
F_dropout: str #
|
||||
F_qscale: str # no/pertensor
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_kv_memory_layout: str #
|
||||
F_kv_lookup_table: str #
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
@@ -375,6 +405,8 @@ class FmhaFwdPipeline:
|
||||
n += f"_{self.F_qscale}"
|
||||
else:
|
||||
n += "_nqscale"
|
||||
|
||||
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
|
||||
return n
|
||||
|
||||
|
||||
@@ -433,6 +465,13 @@ class FmhaFwdApiPool:
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
|
||||
trait.kv_memory_layout
|
||||
],
|
||||
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
|
||||
trait.kv_lookup_table
|
||||
],
|
||||
F_page_size=trait.page_size,
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
@@ -490,10 +529,12 @@ class FmhaFwdKernel:
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
F_page_size: int = 1 # page block size
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
|
||||
F_kname=self.name,
|
||||
F_idx=self.F_idx,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
@@ -526,17 +567,24 @@ class FmhaFwdKernel:
|
||||
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
|
||||
self.F_pipeline.F_kv_memory_layout
|
||||
],
|
||||
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
|
||||
self.F_pipeline.F_kv_lookup_table
|
||||
],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_page_size=self.F_page_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
@@ -570,16 +618,23 @@ class FmhaFwdKernel:
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
|
||||
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
|
||||
page_size=self.F_page_size,
|
||||
)
|
||||
|
||||
|
||||
class KernelComponentFactory:
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
} # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
return {
|
||||
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -589,20 +644,45 @@ class KernelComponentFactory:
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
qscale = "no"
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(
|
||||
qscale = "no"
|
||||
for (
|
||||
logits,
|
||||
mask,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
# no need lse/dropout kernels
|
||||
for (
|
||||
logits,
|
||||
qscale,
|
||||
mask,
|
||||
bias,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
["pertensor"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -612,7 +692,7 @@ class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
if 128 in result.keys():
|
||||
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
|
||||
return result
|
||||
@@ -654,70 +734,75 @@ def get_fwd_blobs(
|
||||
or pipeline.F_logits == "f"
|
||||
):
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
# Generate kernels for both page_size=16 and page_size=1024
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
F_page_size=page_size,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
@@ -40,7 +40,16 @@ DTYPE_BITS = {
|
||||
"bf8": 8,
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 48: 48, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32: 32,
|
||||
48: 48,
|
||||
64: 64,
|
||||
80: 96,
|
||||
96: 128,
|
||||
128: 128,
|
||||
192: 192,
|
||||
256: 256,
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
@@ -202,11 +211,10 @@ float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream
|
||||
const bool can_dispatch_v3 =
|
||||
(device_name.compare(0, 6, "gfx950") == 0) and
|
||||
(traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
|
||||
traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and
|
||||
(traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and
|
||||
(not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and
|
||||
(not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and
|
||||
(args.hdim_v == 128);
|
||||
traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
|
||||
(not traits.has_lse) and (not traits.has_dropout) and
|
||||
(traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
|
||||
(args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
|
||||
if ({F_is_v3_enabled} and can_dispatch_v3) {{
|
||||
return fmha_fwd_v3(traits, args, config);
|
||||
}} else {{
|
||||
@@ -930,6 +938,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
( 80, 96) : [FmhaFwdTileSize(128, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
@@ -1008,14 +1017,18 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias, sink in itertools.product(
|
||||
["f"],
|
||||
["t", "f"],
|
||||
["no", "pertensor"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
["f", "t"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
if hdim == 64:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
elif dtype in ["fp8", "fp8fp16", "bf8"]:
|
||||
# TODO
|
||||
pass
|
||||
@@ -1068,9 +1081,9 @@ class KernelComponentFactoryGfx950(
|
||||
# qr_async_trload_v3 only supports hdim=hdim_v=128 for now
|
||||
if (hdim, hdim_v) == (128, 128):
|
||||
# qr_async_trload_v3 only supports (generic) causal mask
|
||||
for mask in ["no", "causal"]:
|
||||
for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
|
||||
F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
|
||||
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
|
||||
|
||||
return pipelines
|
||||
|
||||
|
||||
@@ -114,7 +114,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("kv_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.");
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("init_sink", "0", "value to init the output tensor sink value for validation");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -157,6 +158,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int init_sink_value = arg_parser.get_int("init_sink");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
@@ -203,6 +205,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
init_method,
|
||||
seed,
|
||||
do_validation,
|
||||
init_sink_value,
|
||||
stream_config,
|
||||
json);
|
||||
}
|
||||
|
||||
@@ -621,8 +621,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision
|
||||
ck_tile::HostTensor<AccDataType> p_dropped_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision
|
||||
ck_tile::HostTensor<GemmDataType> p_lp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision
|
||||
|
||||
// p_lp_g_m_n low precision used for fwd (with rp_undrop)
|
||||
ck_tile::HostTensor<GemmDataType> p_fwd_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
||||
// p_lp_g_m_n low precision used for bwd (no rp_undrop)
|
||||
ck_tile::HostTensor<GemmDataType> p_lp_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
||||
|
||||
ck_tile::index_t nr = nhead / nhead_k;
|
||||
|
||||
@@ -762,8 +765,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::reference_batched_dropout_randval(
|
||||
randval_host_ref, wb, drop_seed, drop_offset);
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, 1.f);
|
||||
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
p_dropped_hp_host_ref.ForEach(
|
||||
[&](auto& self, const auto& idx) { self(idx) *= rp_undrop; });
|
||||
p_fwd_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_result(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
@@ -789,12 +795,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
}
|
||||
else
|
||||
{
|
||||
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
p_fwd_host_ref = p_lp_host_ref;
|
||||
}
|
||||
|
||||
// O = P * V
|
||||
ck_tile::reference_batched_gemm<GemmDataType, VDataType, AccDataType, ODataType>(
|
||||
p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
|
||||
p_fwd_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
|
||||
|
||||
// clang-format off
|
||||
// permute
|
||||
@@ -900,7 +907,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::reference_batched_dropout(
|
||||
dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, rp_undrop);
|
||||
dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, 1.f);
|
||||
}
|
||||
|
||||
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
|
||||
@@ -911,7 +918,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
{
|
||||
do_dot_o +=
|
||||
ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[ref_idx](i0, i1, o));
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[ref_idx](i0, i1, o)) *
|
||||
p_undrop;
|
||||
}
|
||||
ds_hp_host_ref(i0, i1, i2) =
|
||||
ck_tile::type_convert<AccDataType>(p_hp_host_refs[ref_idx](i0, i1, i2) *
|
||||
@@ -935,7 +943,12 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
|
||||
ck_tile::
|
||||
reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
|
||||
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
|
||||
p_t_lp_host_ref,
|
||||
do_t_host_ref,
|
||||
dv_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(rp_undrop)); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
|
||||
|
||||
// dQ = scale * dS@K^T
|
||||
auto k_t_host_ref = k_host_refs[ref_idx].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
|
||||
@@ -945,7 +958,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dq_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n
|
||||
ck_tile::scales(scale * rp_undrop)); // dq_g_m_k = ds_g_m_n@k_g_k_n
|
||||
|
||||
// dK = scale * dS^T@Q^T
|
||||
auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
|
||||
@@ -956,7 +969,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dk_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m
|
||||
ck_tile::scales(scale * rp_undrop)); // dk_g_n_k = ds_g_n_m@q_g_k_m
|
||||
|
||||
ck_tile::HostTensor<QGradDataType> dq_host_result(
|
||||
{nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
||||
|
||||
@@ -230,6 +230,7 @@ struct fmha_fwd_args
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -317,6 +318,7 @@ struct fmha_fwd_pagedkv_args
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -400,6 +402,7 @@ struct fmha_fwd_splitkv_args
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -476,6 +479,7 @@ struct fmha_fwd_appendkv_args
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
@@ -500,6 +504,9 @@ struct fmha_batch_prefill_args
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
const void* q_descale_ptr;
|
||||
const void* k_descale_ptr;
|
||||
const void* v_descale_ptr;
|
||||
void* rand_val_ptr;
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
@@ -516,6 +523,7 @@ struct fmha_batch_prefill_args
|
||||
// 1) +
|
||||
// kargs.kv_last_page_lens[b]
|
||||
const void* seqstart_q_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -526,14 +534,25 @@ struct fmha_batch_prefill_args
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
// SGLang-style page table
|
||||
int32_t num_total_pages;
|
||||
void* kv_indptr;
|
||||
void* kv_page_indices;
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
void* kv_last_page_lens;
|
||||
ck_tile::index_t page_block_size;
|
||||
#endif
|
||||
// KV cache page table fields (kv_lookup_table selects interpretation):
|
||||
// - SGLANG_PAGE_TABLE_1D:
|
||||
// kv_indptr: prefix-sum [batch+1] into kv_page_indices
|
||||
// kv_page_indices: 1D list of physical page ids, length = num_total_pages
|
||||
// kv_last_page_lens: per-batch last page lengths [batch]
|
||||
// - VLLM_BLOCK_TABLE_2D:
|
||||
// kv_page_indices: block_table [batch, max_blocks_per_seq] (2D)
|
||||
// batch_stride_block_table: row stride for block_table
|
||||
// seqlen_k_ptr: per-batch seqlen_k [batch]
|
||||
int32_t num_total_pages; // total physical pages in KV cache (SGLang/vLLM)
|
||||
ck_tile::index_t page_block_size; // tokens per page (SGLang/vLLM)
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum
|
||||
kv_memory_layout; // KV memory layout (SGLang/vLLM)
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; // lookup table layout selector
|
||||
void* kv_indptr; // SGLang: prefix-sum; vLLM: unused
|
||||
void* kv_page_indices; // SGLang: 1D page list; vLLM: block_table 2D
|
||||
void* kv_last_page_lens; // SGLang: last page lengths; vLLM: unused
|
||||
void* seqlen_k_ptr; // vLLM: per-batch seqlen_k; SGLang: unused
|
||||
ck_tile::index_t batch_stride_block_table; // vLLM: row stride; SGLang: unused
|
||||
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
@@ -624,7 +643,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -674,7 +694,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -728,6 +749,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
@@ -758,6 +780,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
@@ -832,7 +855,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q);
|
||||
args.min_seqlen_q,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -877,7 +901,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -944,7 +969,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -992,7 +1018,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1108,6 +1135,22 @@ template <typename FmhaKernel>
|
||||
auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
using PageTableKargs = typename FmhaKernel::PageBlockTableKargs;
|
||||
const PageTableKargs page_table = [&]() {
|
||||
if constexpr(FmhaKernel::kKVLookupTable ==
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_indptr),
|
||||
reinterpret_cast<const int32_t*>(args.kv_page_indices),
|
||||
reinterpret_cast<const int32_t*>(args.kv_last_page_lens)};
|
||||
}
|
||||
else
|
||||
{
|
||||
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_page_indices),
|
||||
args.batch_stride_block_table,
|
||||
reinterpret_cast<const int32_t*>(args.seqlen_k_ptr)};
|
||||
}
|
||||
}();
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
@@ -1116,6 +1159,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.q_descale_ptr,
|
||||
args.k_descale_ptr,
|
||||
args.v_descale_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
@@ -1125,12 +1171,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
@@ -1156,7 +1198,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -1164,6 +1207,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.q_descale_ptr,
|
||||
args.k_descale_ptr,
|
||||
args.v_descale_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
@@ -1173,12 +1219,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
@@ -1209,7 +1251,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1270,6 +1313,65 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
bool kHasLogitsSoftCap_,
|
||||
typename FmhaMask_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
bool kHasDropout_,
|
||||
ck_tile::BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kUseTrLoad_,
|
||||
bool kSkipMinSeqlenQ_ = false,
|
||||
ck_tile::index_t kPageBlockSize_ = 1,
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
|
||||
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
|
||||
DataType_,
|
||||
kIsGroupMode_,
|
||||
kM0_,
|
||||
kN0_,
|
||||
kK0_,
|
||||
kN1_,
|
||||
kK1_,
|
||||
kK0BlockLength_,
|
||||
kIsVLayoutRowMajor_,
|
||||
FmhaPipelineEnum_,
|
||||
kHasLogitsSoftCap_,
|
||||
FmhaMask_,
|
||||
BiasEnum_,
|
||||
kStoreLse_,
|
||||
kHasDropout_,
|
||||
QScaleEnum_,
|
||||
kPadS_,
|
||||
kPadSK_,
|
||||
kPadD_,
|
||||
kPadDv_,
|
||||
kUseTrLoad_,
|
||||
kSkipMinSeqlenQ_,
|
||||
false>
|
||||
{
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
|
||||
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
@@ -1516,7 +1618,15 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
|
||||
fmha_fwd_appendkv_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
using fmha_batch_prefill_traits = fmha_fwd_traits;
|
||||
struct fmha_batch_prefill_traits : public fmha_fwd_traits
|
||||
{
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
|
||||
int page_size = 1;
|
||||
};
|
||||
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits,
|
||||
fmha_batch_prefill_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
@@ -149,6 +149,28 @@ int override_num_splits_if_necessary(
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
template <typename SMPLComputeDataType>
|
||||
void copy_attention_scores_with_sink(const ck_tile::HostTensor<SMPLComputeDataType>& s_host_ref,
|
||||
const ck_tile::HostTensor<SMPLComputeDataType>& sink_host,
|
||||
ck_tile::HostTensor<SMPLComputeDataType>& s_with_sinks_ref,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t real_seqlen_q,
|
||||
ck_tile::index_t real_seqlen_k)
|
||||
{
|
||||
for(auto i_h = 0; i_h < nhead; i_h++)
|
||||
{
|
||||
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
||||
{
|
||||
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
||||
{
|
||||
s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c);
|
||||
}
|
||||
// Append sink token at the end of each row
|
||||
s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataTypeConfig>
|
||||
fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
@@ -184,6 +206,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::string init_method,
|
||||
uint32_t seed,
|
||||
int do_validation,
|
||||
int init_sink_value,
|
||||
const ck_tile::stream_config& stream_config,
|
||||
std::optional<std::string> json = std::nullopt)
|
||||
{
|
||||
@@ -527,6 +550,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
@@ -609,6 +633,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
|
||||
bias_host);
|
||||
}
|
||||
|
||||
else if(init_method == "ni")
|
||||
{
|
||||
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
|
||||
@@ -695,10 +720,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
|
||||
|
||||
if(init_sink_value != 0)
|
||||
{
|
||||
// sink is initialized to a fixed integer value for easy debugging and use 30 to 60 range
|
||||
// for close to rowmax values.
|
||||
ck_tile::FillUniformDistributionIntegerValue<SMPLComputeDataType>{30.f, 60.f, next_seed()}(
|
||||
sink_host);
|
||||
}
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
@@ -743,6 +775,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
sink_buf.ToDevice(sink_host.data());
|
||||
knew_buf.ToDevice(knew_host.data());
|
||||
vnew_buf.ToDevice(vnew_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
@@ -971,7 +1004,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
|
||||
if(init_sink_value != 0)
|
||||
args.sink_ptr = sink_buf.GetDeviceBuffer();
|
||||
else
|
||||
args.sink_ptr = nullptr;
|
||||
args.batch = batch;
|
||||
args.seqlen_q = shape_seqlen_q; // unused in group mode
|
||||
args.hdim_q = hdim_q;
|
||||
@@ -1351,8 +1387,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_qscale)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o_host});
|
||||
return ck_tile::make_composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o_host});
|
||||
else if constexpr(supports_qscale)
|
||||
return ck_tile::scales{scale_o_host};
|
||||
else
|
||||
@@ -1675,19 +1711,57 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
const ck_tile::HostTensor<SaccDataType> masked_s_host_ref = s_host_ref;
|
||||
if(lse)
|
||||
if(init_sink_value != 0)
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
||||
// Create extended tensor with sink token
|
||||
ck_tile::HostTensor<SMPLComputeDataType> s_with_sinks_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k + 1});
|
||||
|
||||
// Copy original attention scores and append sink values
|
||||
copy_attention_scores_with_sink(
|
||||
s_host_ref, sink_host, s_with_sinks_ref, nhead, real_seqlen_q, real_seqlen_k);
|
||||
|
||||
// Compute softmax on extended tensor
|
||||
ck_tile::HostTensor<PDataType> p_extended(
|
||||
{nhead, real_seqlen_q, real_seqlen_k + 1});
|
||||
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_with_sinks_ref, p_extended, p_compute_element_func, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_with_sinks_ref, p_extended, p_compute_element_func);
|
||||
}
|
||||
|
||||
// Extract only the original columns (exclude sink token column)
|
||||
p_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = p_extended(idx[0], idx[1], idx[2]); });
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
// No sink tokens - compute softmax directly
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func);
|
||||
}
|
||||
}
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
|
||||
@@ -84,3 +84,10 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
@@ -69,107 +69,88 @@ struct BasicInvoker
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr =
|
||||
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0],
|
||||
kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -72,160 +72,144 @@ struct SplitKTwoStageInvoker
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WorkspaceType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WorkspaceType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType));
|
||||
ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args);
|
||||
auto c_ptr = ws_args.c_ptr;
|
||||
ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args);
|
||||
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType));
|
||||
ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args);
|
||||
auto c_ptr = ws_args.c_ptr;
|
||||
ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args);
|
||||
const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s)
|
||||
: GemmKernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
|
||||
const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s)
|
||||
: GemmKernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
if(!GemmKernel::IsSupportedArgument(gemm_kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(!GemmKernel::IsSupportedArgument(gemm_kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using BlockTile = ck_tile::sequence<2048>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using BlockTile = ck_tile::sequence<2048>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceType,
|
||||
WorkspaceType,
|
||||
CDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
using ElementwiseKernel =
|
||||
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceType,
|
||||
WorkspaceType,
|
||||
CDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
using ElementwiseKernel =
|
||||
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {args.M, args.N};
|
||||
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {args.M, args.N};
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
|
||||
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize =
|
||||
(total_elements + elements_per_block - 1) / elements_per_block;
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceType*>(ws_args.c_ptr));
|
||||
auto input_size = ck_tile::make_tuple(args.M, args.N);
|
||||
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceType*>(ws_args.c_ptr));
|
||||
auto input_size = ck_tile::make_tuple(args.M, args.N);
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseKernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseKernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr =
|
||||
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
gemm_kargs.as_ptr[0],
|
||||
gemm_kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
GemmKernel{}, grids, blocks, 0, gemm_kargs),
|
||||
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(args.N, 1), // Input Stride
|
||||
ck_tile::make_tuple(args.N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<CDataType*>(c_ptr)));
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
gemm_kargs.as_ptr[0],
|
||||
gemm_kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
GemmKernel{}, grids, blocks, 0, gemm_kargs),
|
||||
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(args.N, 1), // Input Stride
|
||||
ck_tile::make_tuple(args.N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<CDataType*>(c_ptr)));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -160,110 +160,101 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
|
||||
args.stride_E);
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const auto Run = [&]() {
|
||||
// use SET operation since each K-split writes to separate memory
|
||||
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue =
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(base_args);
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(base_args);
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
return Run();
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -286,7 +277,6 @@ template <typename CDataType,
|
||||
typename ELayout = ck_tile::tensor_layout::gemm::RowMajor>
|
||||
float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
const ck_tile::index_t reduce_dim_size = args.k_batch; // Number of partial results to reduce
|
||||
// Calculate output size based on the final output tensor dimensions
|
||||
const ck_tile::index_t output_size = args.M * args.N;
|
||||
|
||||
@@ -303,27 +293,28 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config
|
||||
constexpr auto reduce_dims = ck_tile::sequence<0>{}; // Reduce k_batch dimension
|
||||
|
||||
using ReduceOp = ck_tile::ReduceOp::Add;
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using ThreadTile = ck_tile::sequence<8, 8>;
|
||||
using BlockWarps = ck_tile::sequence<1, 1>;
|
||||
using BlockTile = ck_tile::sequence<256, 1>;
|
||||
using WarpTile = ck_tile::sequence<256, 1>;
|
||||
using ThreadTile = ck_tile::sequence<1, 1>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
ck_tile::index_t kGridSize = (output_size + BlockTile::at(ck_tile::number<0>{}) - 1) /
|
||||
BlockTile::at(ck_tile::number<0>{});
|
||||
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
using Problem =
|
||||
ck_tile::Reduce2dProblem<CDataType, ComputeDataType, CDataType, Shape, ReduceOp>;
|
||||
using Kernel = ck_tile::Reduce<Problem>;
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
using Problem = ck_tile::Reduce2dProblem<CDataType,
|
||||
ComputeDataType,
|
||||
CDataType,
|
||||
Shape,
|
||||
ReduceOp,
|
||||
decltype(kept_dim),
|
||||
decltype(reduce_dims),
|
||||
3>;
|
||||
using Kernel = ck_tile::ReduceKernel<Problem>;
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(reduce_dim_size, workspace_strides))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Reduction arguments not supported!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Stage 2 - Launching Reduction kernel" << '\n'
|
||||
@@ -343,9 +334,7 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config
|
||||
static_cast<const CDataType*>(args.e_ptr), // workspace input
|
||||
static_cast<CDataType*>(args.final_output_ptr), // final output
|
||||
workspace_shape,
|
||||
workspace_strides,
|
||||
kept_dim,
|
||||
reduce_dims));
|
||||
workspace_strides));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -460,12 +460,6 @@ inline auto create_args()
|
||||
return arg_parser;
|
||||
}
|
||||
|
||||
// Type aliases for memory operation integral constants
|
||||
using MemoryOpSet =
|
||||
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
|
||||
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>;
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
|
||||
@@ -57,114 +57,95 @@ struct WeightPreshuffleInvoker
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
GemmConfig::TiledMMAPermuteN>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
GemmConfig::TiledMMAPermuteN>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
float ave_time = 0.f;
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(kargs.as_ptr[0],
|
||||
kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time =
|
||||
ck_tile::launch_kernel_time_mask(s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl;
|
||||
}
|
||||
float ave_time = 0.f;
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -60,112 +60,94 @@ struct UniversalInvoker
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s)
|
||||
: Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s)
|
||||
: Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr =
|
||||
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0],
|
||||
kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -15,6 +15,22 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo
|
||||
|
||||
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# Multi Reduce Threadwise Example
|
||||
set(EXAMPLE_MULTI_REDUCE "tile_example_multi_reduce_threadwise")
|
||||
add_executable(${EXAMPLE_MULTI_REDUCE} EXCLUDE_FROM_ALL multiple_reduce_threadwise.cpp)
|
||||
target_include_directories(${EXAMPLE_MULTI_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
target_compile_options(${EXAMPLE_MULTI_REDUCE} PRIVATE ${EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# Multi Reduce Blockwise Example
|
||||
set(EXAMPLE_MULTI_REDUCE_BLOCKWISE "tile_example_multi_reduce_multiblock")
|
||||
add_executable(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} EXCLUDE_FROM_ALL multiple_reduce_multiblock.cpp)
|
||||
target_include_directories(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
target_compile_options(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
|
||||
271
example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp
Normal file
271
example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp
Normal file
@@ -0,0 +1,271 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("n", "32", "n dimension")
|
||||
.insert("h", "19", "h dimension")
|
||||
.insert("w", "7", "w dimension")
|
||||
.insert("c", "512", "c dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "multi_reduce_multiblock.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = float;
|
||||
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t H = arg_parser.get_int("h");
|
||||
ck_tile::index_t W = arg_parser.get_int("w");
|
||||
ck_tile::index_t C = arg_parser.get_int("c");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
// Validate input dimensions
|
||||
const ck_tile::index_t kept_dim_len_prod = N * C;
|
||||
const ck_tile::index_t reduce_total_length = H * W;
|
||||
|
||||
if(kept_dim_len_prod == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C
|
||||
<< ", product=" << kept_dim_len_prod << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty output tensor." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(reduce_total_length == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W
|
||||
<< ", product=" << reduce_total_length << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty reduction with no data to process." << std::endl;
|
||||
std::cerr << "The kernel will exit early without performing any computation." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> problem_shape = {N, H, W, C};
|
||||
std::vector<ck_tile::index_t> strides(4);
|
||||
strides[0] = H * W * C;
|
||||
strides[1] = W * C;
|
||||
strides[2] = C;
|
||||
strides[3] = 1;
|
||||
|
||||
// Define reduction specification:
|
||||
constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep
|
||||
constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);
|
||||
ck_tile::HostTensor<YDataType> y_host_add_ref({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_ref({N, C}, {C, 1});
|
||||
auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref);
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_add_dev({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_dev({N, C}, {C, 1});
|
||||
auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev);
|
||||
|
||||
const auto number_operations = y_host_dev_tuple.size();
|
||||
|
||||
std::vector<YDataType> h(number_operations * N * C);
|
||||
|
||||
auto y_buf_size = number_operations *
|
||||
y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes();
|
||||
ck_tile::DeviceMem y_buf(y_buf_size);
|
||||
|
||||
const auto output_tensor_offset = N * C;
|
||||
|
||||
// Operations: one doing a sum reduction, the other computing the mean square
|
||||
// In the case of mean square:
|
||||
// 1. The element wise operation squares each element before reduction
|
||||
// 2. The reduction operation sum the squared element
|
||||
// 3. The accumulator element wise operation divides the result by the total number of reduced
|
||||
// elements (intra block operation)
|
||||
// 4. The partial result is updated across blocks using inter block reduction, a sum.
|
||||
auto reduce_ops =
|
||||
ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions
|
||||
auto elementwise_ops = ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnarySquare{}); // Elementwise
|
||||
// ops
|
||||
auto accumulator_elementwise_ops = ck_tile::make_tuple(
|
||||
ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnaryDivide{
|
||||
reduce_total_length}); // Accumulator Elementwise ops on reduction, intra block
|
||||
auto inter_block_reduce_ops = ck_tile::make_tuple(
|
||||
ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // Inter block reduction
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using ThreadTile = ck_tile::sequence<8, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
using Problem = ck_tile::Reduce2dProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
Shape,
|
||||
decltype(reduce_ops),
|
||||
decltype(kept_dim),
|
||||
decltype(reduce_dims),
|
||||
4>;
|
||||
|
||||
using Kernel = ck_tile::MultiReduceMultiblock<Problem>;
|
||||
|
||||
// Determine block group size for multi-block reduction
|
||||
// block_group_size records how many blocks participate to a reduction (input data dependent)
|
||||
// , for efficiency reasons this size if limited to a maximum of 128. If this is not sufficient
|
||||
// to process the whole reduction, each thread will to process multiple thread tile
|
||||
// a num_block_tile_iterations times
|
||||
auto [num_block_tile_iterations, block_group_size] =
|
||||
typename Kernel::TilePartitioner{reduce_total_length}.GetBlockGroupParams();
|
||||
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
ck_tile::index_t kGridSize =
|
||||
((kept_dim_len_prod + Shape::Block_M - 1) / Shape::Block_M) * block_group_size;
|
||||
|
||||
std::cout << "Block group size: " << block_group_size
|
||||
<< ", Num block tile iterations: " << num_block_tile_iterations
|
||||
<< ", Reduce total length: " << reduce_total_length << std::endl;
|
||||
std::cout << "grid size " << kGridSize << ", block size " << kBlockSize << std::endl;
|
||||
|
||||
// Create input tensor shape and strides
|
||||
auto input_shape =
|
||||
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
|
||||
auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(
|
||||
C, input_strides)) // output tensor's continuous dimension and input strides
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
|
||||
// Init the output data with identity values respective to each reduce op
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
constexpr auto op = reduce_ops.at(i);
|
||||
const auto identity_val = op.template GetIdentityValue<YDataType>();
|
||||
const auto output_number_elements = N * C;
|
||||
std::fill(h.begin() + i * output_number_elements,
|
||||
h.begin() + (i + 1) * output_number_elements,
|
||||
identity_val);
|
||||
});
|
||||
|
||||
auto clear_output_buffer = [&]() { y_buf.ToDevice(h.data()); };
|
||||
|
||||
float ave_time = launch_kernel_time_mask(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
clear_output_buffer,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
input_strides,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
output_tensor_offset,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops,
|
||||
inter_block_reduce_ops)
|
||||
|
||||
);
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
ck_tile::reference_multiple_reduce_multiblock<XDataType, ComputeDataType, YDataType>(
|
||||
x_host,
|
||||
y_host_ref_tuple,
|
||||
reduce_ops,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops,
|
||||
inter_block_reduce_ops,
|
||||
block_group_size);
|
||||
std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl;
|
||||
|
||||
// Transfer data from device and check error for each operation
|
||||
y_buf.FromDevice(h.data());
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
std::memcpy(y_host_dev_tuple.get(ck_tile::number<i>{}).data(),
|
||||
h.data() + i * output_tensor_offset,
|
||||
output_tensor_offset * sizeof(YDataType));
|
||||
std::cout << "Checking operation " << i << ": " << std::endl;
|
||||
|
||||
bool pass_op = ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number<i>{}),
|
||||
y_host_ref_tuple.get(ck_tile::number<i>{}));
|
||||
|
||||
if(pass_op)
|
||||
{
|
||||
std::cout << "✅ valid results for this operation" << std::endl;
|
||||
}
|
||||
pass &= pass_op;
|
||||
});
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
224
example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp
Normal file
224
example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp
Normal file
@@ -0,0 +1,224 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("n", "32", "n dimension")
|
||||
.insert("h", "7", "h dimension")
|
||||
.insert("w", "7", "w dimension")
|
||||
.insert("c", "512", "c dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "multi_reduce.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = DataType;
|
||||
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t H = arg_parser.get_int("h");
|
||||
ck_tile::index_t W = arg_parser.get_int("w");
|
||||
ck_tile::index_t C = arg_parser.get_int("c");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
// Validate input dimensions
|
||||
const ck_tile::index_t kept_dim_len_prod = N * C;
|
||||
const ck_tile::index_t reduce_total_length = H * W;
|
||||
|
||||
if(kept_dim_len_prod == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C
|
||||
<< ", product=" << kept_dim_len_prod << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty output tensor." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(reduce_total_length == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W
|
||||
<< ", product=" << reduce_total_length << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty reduction with no data to process." << std::endl;
|
||||
std::cerr << "The kernel will exit early without performing any computation." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> problem_shape = {N, H, W, C};
|
||||
std::vector<ck_tile::index_t> strides(4);
|
||||
strides[0] = H * W * C;
|
||||
strides[1] = W * C;
|
||||
strides[2] = C;
|
||||
strides[3] = 1;
|
||||
|
||||
// Define reduction specification:
|
||||
constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep
|
||||
constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);
|
||||
ck_tile::HostTensor<YDataType> y_host_add_ref({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_ref({N, C}, {C, 1});
|
||||
auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref);
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_add_dev({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_dev({N, C}, {C, 1});
|
||||
auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev);
|
||||
|
||||
const auto number_operations = y_host_dev_tuple.size();
|
||||
|
||||
// Two operations: one do a sum reduction, the other computing the mean square
|
||||
auto reduce_ops =
|
||||
ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions ops
|
||||
auto elementwise_ops =
|
||||
ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnarySquare{}); // Elementwise ops
|
||||
auto accumulator_elementwise_ops =
|
||||
ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnaryDivide{
|
||||
reduce_total_length}); // Accumulator Elementiwise ops on reduction,
|
||||
|
||||
auto y_buf_size = number_operations *
|
||||
y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes();
|
||||
ck_tile::DeviceMem y_buf(y_buf_size);
|
||||
|
||||
const auto output_tensor_offset = N * C;
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using ThreadTile = ck_tile::sequence<8, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) /
|
||||
BlockTile::at(ck_tile::number<0>{});
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
using Problem = ck_tile::Reduce2dProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
Shape,
|
||||
decltype(reduce_ops),
|
||||
decltype(kept_dim),
|
||||
decltype(reduce_dims),
|
||||
4>;
|
||||
|
||||
using Kernel = ck_tile::MultiReduceThreadWise<Problem>;
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
|
||||
// Create input tensor shape and strides
|
||||
auto input_shape =
|
||||
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
|
||||
auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(
|
||||
C, input_strides)) // output tensor's continuous dimension and input strides
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
|
||||
float ave_time = launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
input_strides,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
output_tensor_offset,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops));
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
std::vector<YDataType> h(number_operations * N * C);
|
||||
|
||||
// reference
|
||||
ck_tile::reference_multiple_reduce<XDataType, ComputeDataType, YDataType>(
|
||||
x_host,
|
||||
y_host_ref_tuple,
|
||||
reduce_ops,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops);
|
||||
std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl;
|
||||
|
||||
// Transfer data from device and check error for each operation
|
||||
y_buf.FromDevice(h.data());
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
std::memcpy(y_host_dev_tuple.get(ck_tile::number<i>{}).data(),
|
||||
h.data() + i * output_tensor_offset,
|
||||
output_tensor_offset * sizeof(YDataType));
|
||||
pass &= ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number<i>{}),
|
||||
y_host_ref_tuple.get(ck_tile::number<i>{}));
|
||||
});
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
@@ -9,14 +9,14 @@
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("n", "32", "n dimension")
|
||||
.insert("h", "7", "h dimension")
|
||||
.insert("w", "7", "w dimension")
|
||||
.insert("c", "512", "c dimension")
|
||||
arg_parser.insert("n", "16", "n dimension")
|
||||
.insert("h", "64", "h dimension")
|
||||
.insert("w", "32", "w dimension")
|
||||
.insert("c", "960", "c dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("warmup", "20", "cold iter")
|
||||
.insert("repeat", "100", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "reduce.json", "json file name to dump results");
|
||||
|
||||
@@ -47,12 +47,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
strides[3] = 1;
|
||||
|
||||
// Define reduction specification:
|
||||
constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep
|
||||
constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce
|
||||
constexpr auto kept_dim = ck_tile::sequence<1, 2, 3>{}; // Which dimension to keep
|
||||
constexpr auto reduce_dims = ck_tile::sequence<0>{}; // Which dimensions to reduce
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({H, W, C}, {W * C, C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({H, W, C}, {W * C, C, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
|
||||
@@ -62,40 +62,40 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
using ReduceOp = ck_tile::ReduceOp::Add;
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using Vector = ck_tile::sequence<8, 8>;
|
||||
using BlockWarps = ck_tile::sequence<1, 1>;
|
||||
using BlockTile = ck_tile::sequence<256, 1>;
|
||||
using WarpTile = ck_tile::sequence<256, 1>;
|
||||
using ThreadTile = ck_tile::sequence<1, 1>;
|
||||
|
||||
// cross warp-reduce
|
||||
// using BlockWarps = ck_tile::sequence<2, 2>;
|
||||
// using BlockTile = ck_tile::sequence<2, 1024>;
|
||||
// using WarpTile = ck_tile::sequence<1, 512>;
|
||||
// using Vector = ck_tile::sequence<1, 8>;
|
||||
// using ThreadTile = ck_tile::sequence<1, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
ck_tile::index_t kept_dim_len_prod = N * C;
|
||||
ck_tile::index_t kept_dim_len_prod = H * W * C;
|
||||
ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) /
|
||||
BlockTile::at(ck_tile::number<0>{});
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, Vector>;
|
||||
using Porblem =
|
||||
ck_tile::Reduce2dProblem<XDataType, ComputeDataType, YDataType, Shape, ReduceOp>;
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
using Porblem = ck_tile::Reduce2dProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
Shape,
|
||||
ReduceOp,
|
||||
decltype(kept_dim),
|
||||
decltype(reduce_dims),
|
||||
4>;
|
||||
|
||||
using Kernel = ck_tile::Reduce<Porblem>;
|
||||
using Kernel = ck_tile::ReduceKernel<Porblem>;
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
// Create input tensor shape and strides
|
||||
auto input_shape =
|
||||
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
|
||||
auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(
|
||||
C, input_strides)) // output tensor's continuous dimension and input strides
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
|
||||
float ave_time = launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
@@ -105,11 +105,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
input_strides,
|
||||
kept_dim,
|
||||
reduce_dims));
|
||||
input_strides));
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;
|
||||
std::size_t num_btype = sizeof(XDataType) * N * H * W * C + sizeof(YDataType) * H * W * C;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
@@ -149,8 +147,8 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
// else if(data_type == "bf16")
|
||||
// {
|
||||
// return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
// }
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -334,13 +334,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
printf("moe_buf:%lu(%d,%d), ",
|
||||
printf("moe_buf:%" PRIu64 "(%d,%d), ",
|
||||
static_cast<uint64_t>(moe_buf_bytes),
|
||||
moe_buf_interm_dim,
|
||||
moe_buf_elem_bytes);
|
||||
#else
|
||||
|
||||
printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
|
||||
printf("moe_buf:%" PRIu64 ", ", static_cast<uint64_t>(moe_buf_bytes));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -78,63 +78,48 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
else
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
@@ -3,7 +3,18 @@
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
add_executable(tile_example_grouped_gemm grouped_gemm.cpp)
|
||||
add_executable(tile_example_quant_grouped_gemm quant_grouped_gemm.cpp)
|
||||
add_executable(tile_example_quant_grouped_gemm
|
||||
quant_grouped_gemm.cpp
|
||||
quant_grouped_gemm_fp8_aquant.cpp
|
||||
quant_grouped_gemm_fp8_bquant.cpp
|
||||
quant_grouped_gemm_fp8_rowcol.cpp
|
||||
quant_grouped_gemm_fp8_tensor.cpp
|
||||
quant_grouped_gemm_bf8_aquant.cpp
|
||||
quant_grouped_gemm_bf8_bquant.cpp
|
||||
quant_grouped_gemm_bf8_rowcol.cpp
|
||||
quant_grouped_gemm_bf8_tensor.cpp
|
||||
)
|
||||
add_executable(tile_example_abquant_grouped_gemm abquant_grouped_gemm.cpp)
|
||||
add_executable(tile_example_grouped_gemm_preshuffle grouped_gemm_preshuffle.cpp)
|
||||
add_executable(tile_example_grouped_gemm_multi_d grouped_gemm_multi_d.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
@@ -14,4 +25,5 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_abquant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
278
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp
Normal file
278
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp
Normal file
@@ -0,0 +1,278 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "abquant_grouped_gemm.hpp"
|
||||
|
||||
// Non-persistent grouped gemm for ABQuant
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
// Persistent grouped gemm tileloop for ABQuant
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
GemmConfig::TransposeC>;
|
||||
|
||||
using GemmPipeline = GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_abquant_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
int result1 = run_abquant_grouped_gemm_example(argc, argv);
|
||||
return result1;
|
||||
}
|
||||
171
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp
Normal file
171
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp
Normal file
@@ -0,0 +1,171 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <bool Persistent_>
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool Persistent = Persistent_;
|
||||
};
|
||||
|
||||
template <typename PrecType, bool Persistent>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <ck_tile::QuantType QuantMode>
|
||||
struct GemmQuantConfig;
|
||||
|
||||
// ABQuant specialization for GemmQuantConfig
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::ABQuantGrouped>
|
||||
{
|
||||
template <typename PrecType, bool Persistent>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
|
||||
.insert("Ns", "", "N dimensions - empty by default.")
|
||||
.insert("Ks", "", "K dimensions - empty by default.")
|
||||
.insert(
|
||||
"stride_As",
|
||||
"",
|
||||
"Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs
|
||||
// can be set to zero if
|
||||
// Ms/Ns/Ks is not empty
|
||||
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
|
||||
.insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("init", "0", "0. Random, 2. One(s) (Constant)")
|
||||
.insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.")
|
||||
.insert("bquant_group_size", "1x1x128", "BQuant group size. 1x1x128 (default) or 1x128x128")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "abquant_grouped_gemm.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
|
||||
}
|
||||
|
||||
// Forward declaration of the non-persistent version
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
|
||||
float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr);
|
||||
|
||||
// Forward declaration of the tileloop version for persistent kernels
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr);
|
||||
@@ -62,71 +62,55 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
else
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
@@ -139,8 +123,7 @@ template <typename GemmConfig,
|
||||
typename CDataType>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
void* kargs_ptr)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
@@ -161,74 +144,55 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
float ave_time{0};
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
};
|
||||
|
||||
if(!splitk)
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ave_time =
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
@@ -328,5 +328,4 @@ template <typename GemmConfig,
|
||||
typename CDataType>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk = false);
|
||||
void* kargs_ptr);
|
||||
|
||||
@@ -61,72 +61,56 @@ float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_d
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { "
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
else
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { "
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
@@ -142,8 +126,7 @@ template <typename GemmConfig,
|
||||
typename CDEElementWise>
|
||||
float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
void* kargs_ptr)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
@@ -163,76 +146,55 @@ float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s,
|
||||
BLayout,
|
||||
ELayout>;
|
||||
|
||||
float ave_time{0};
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
if(!splitk)
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_multi_d_example.inc"
|
||||
|
||||
@@ -65,70 +65,54 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
else
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
@@ -141,8 +125,7 @@ template <typename GemmConfig,
|
||||
typename CDataType>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
void* kargs_ptr)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
@@ -167,75 +150,53 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
float ave_time{0};
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(splitk)
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
@@ -3,332 +3,128 @@
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include "quant_run_grouped_gemm_example.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "quant_grouped_gemm.hpp"
|
||||
extern template int run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
extern template int run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
extern template int run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
extern template int run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
extern template int run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
extern template int run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
extern template int run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
extern template int run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
|
||||
.insert("Ns", "", "N dimensions - empty by default.")
|
||||
.insert("Ks", "", "K dimensions - empty by default.")
|
||||
.insert(
|
||||
"stride_As",
|
||||
"",
|
||||
"Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs
|
||||
// can be set to zero if
|
||||
// Ms/Ns/Ks is not empty
|
||||
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
|
||||
.insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Column by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol")
|
||||
.insert("init", "0", "0. Random, 2. One(s) (Constant)")
|
||||
.insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.");
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
ADataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
GemmConfig::TransposeC>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler>>;
|
||||
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
};
|
||||
|
||||
return ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
|
||||
#include "quant_run_grouped_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
int result1 = run_grouped_gemm_example(argc, argv);
|
||||
return result1;
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
std::string quant_mode = arg_parser.get_str("quant_mode");
|
||||
bool persistent = arg_parser.get_bool("persistent");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
|
||||
arg_parser, a_layout, b_layout, persistent);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
|
||||
arg_parser, a_layout, b_layout, persistent);
|
||||
}
|
||||
else if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
arg_parser, a_layout, b_layout, persistent);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
arg_parser, a_layout, b_layout, persistent);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode!");
|
||||
}
|
||||
}
|
||||
if(data_type == "bf8")
|
||||
{
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
|
||||
arg_parser, a_layout, b_layout, persistent);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
|
||||
arg_parser, a_layout, b_layout, persistent);
|
||||
}
|
||||
else if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
arg_parser, a_layout, b_layout, persistent);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
arg_parser, a_layout, b_layout, persistent);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type configuration.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "quant_run_grouped_gemm_example.hpp"
|
||||
|
||||
template int run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "quant_run_grouped_gemm_example.hpp"
|
||||
|
||||
template int run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "quant_run_grouped_gemm_example.hpp"
|
||||
|
||||
template int run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "quant_run_grouped_gemm_example.hpp"
|
||||
|
||||
template int run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
@@ -64,8 +64,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
@@ -152,57 +152,7 @@ struct GemmQuantConfig<ck_tile::QuantType::BQuantGrouped>
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
|
||||
.insert("Ns", "", "N dimensions - empty by default.")
|
||||
.insert("Ks", "", "K dimensions - empty by default.")
|
||||
.insert(
|
||||
"stride_As",
|
||||
"",
|
||||
"Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs
|
||||
// can be set to zero if
|
||||
// Ms/Ns/Ks is not empty
|
||||
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
|
||||
.insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol")
|
||||
.insert("init", "0", "0. Random, 2. One(s) (Constant)")
|
||||
.insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
ck_tile::QuantType QuantMode>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr);
|
||||
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "quant_run_grouped_gemm_example.hpp"
|
||||
|
||||
template int run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "quant_run_grouped_gemm_example.hpp"
|
||||
|
||||
template int run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "quant_run_grouped_gemm_example.hpp"
|
||||
|
||||
template int run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "quant_run_grouped_gemm_example.hpp"
|
||||
|
||||
template int run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
|
||||
const ck_tile::ArgParser&, std::string, std::string, bool);
|
||||
@@ -0,0 +1,300 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
ADataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
GemmConfig::TransposeC>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler>>;
|
||||
|
||||
using GemmPipeline = GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
}
|
||||
@@ -3,6 +3,24 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include "quant_grouped_gemm_config.hpp"
|
||||
#include "quant_invoke_grouped_gemm_kernel.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
@@ -11,9 +29,9 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
static auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
@@ -170,21 +188,13 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout>
|
||||
int run_grouped_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
int run_grouped_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
const ALayout a_layout = ALayout{},
|
||||
const AQLayout aq_layout = AQLayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const BQLayout bq_layout = BQLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
|
||||
};
|
||||
@@ -540,7 +550,9 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::QuantType QuantMode, typename GemmConfig>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser,
|
||||
std::string a_layout,
|
||||
std::string b_layout)
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -556,7 +568,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
@@ -566,102 +577,72 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
AccDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
(QuantMode == ck_tile::QuantType::BQuantGrouped && !GemmConfig::PreshuffleB))
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(
|
||||
arg_parser, Row{}, Row{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(
|
||||
arg_parser, Col{}, Col{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(
|
||||
arg_parser, Col{}, Col{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::QuantType QuantMode>
|
||||
int run_gemm_example_persistency(
|
||||
std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[])
|
||||
int run_gemm_example_persistency(const ck_tile::ArgParser& arg_parser,
|
||||
std::string a_layout,
|
||||
std::string b_layout,
|
||||
bool persistent)
|
||||
{
|
||||
if(persistent)
|
||||
{
|
||||
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, true>;
|
||||
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
arg_parser, a_layout, b_layout);
|
||||
}
|
||||
else
|
||||
{
|
||||
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, false>;
|
||||
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
|
||||
int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
std::string quant_mode = arg_parser.get_str("quant_mode");
|
||||
bool persistent = arg_parser.get_bool("persistent");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode!");
|
||||
}
|
||||
}
|
||||
if(data_type == "bf8")
|
||||
{
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type configuration.");
|
||||
arg_parser, a_layout, b_layout);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,604 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_abquant_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int group_count,
|
||||
const std::vector<grouped_gemm_kargs>& args)
|
||||
{
|
||||
// Workspace memory allocated to hold the gemm descriptions.
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(args));
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if constexpr(!GemmConfig::Persistent)
|
||||
{
|
||||
ave_time = grouped_gemm_abquant<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host. Instead, we can just pass the pointer
|
||||
// to the kernel and let the workgroups figure out which tiles to work on.
|
||||
// This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
if(args[0].k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("Split-K not supported yet for persistent kernel");
|
||||
}
|
||||
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.aq_ptr,
|
||||
arg.bq_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.QK_A,
|
||||
arg.QK_B,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_E,
|
||||
arg.stride_AQ,
|
||||
arg.stride_BQ,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(stream, group_count, kargs_ptr);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout>
|
||||
int run_abquant_grouped_gemm_example_with_layouts(
|
||||
int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const AQLayout aq_layout = AQLayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const BQLayout bq_layout = BQLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
const int init_method = arg_parser.get_int("init");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
|
||||
if(kbatch > 1 && validate && warmup + repeat > 1)
|
||||
{
|
||||
std::cout << "WARNING: Data validation enabled with SplitK and more than"
|
||||
<< "1 warmup/repeat. Disabling validation." << std::endl;
|
||||
validate = false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
std::vector<ck_tile::index_t> AQs; // dimension of AQ tensor is calculated from A tensor
|
||||
std::vector<ck_tile::index_t> BQs; // dimension of BQ tensor is calculated from B tensor
|
||||
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
|
||||
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
|
||||
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
|
||||
std::vector<ck_tile::index_t> stride_AQs = arg_parser.get_int_vec("stride_AQs");
|
||||
std::vector<ck_tile::index_t> stride_BQs = arg_parser.get_int_vec("stride_BQs");
|
||||
|
||||
ck_tile::index_t AQK, BQK;
|
||||
|
||||
if(!valid_input_data(
|
||||
group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs))
|
||||
{
|
||||
std::cout << "Please check the input data. Default values will be used." << std::endl;
|
||||
|
||||
// Clear existing (invalid) data before adding defaults
|
||||
Ms.clear();
|
||||
Ns.clear();
|
||||
Ks.clear();
|
||||
stride_As.clear();
|
||||
stride_Bs.clear();
|
||||
stride_Cs.clear();
|
||||
stride_AQs.clear();
|
||||
stride_BQs.clear();
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
|
||||
// Let get_default_stride calculate based on layout
|
||||
stride_As.push_back(0);
|
||||
stride_Bs.push_back(0);
|
||||
stride_Cs.push_back(0);
|
||||
stride_AQs.push_back(0);
|
||||
stride_BQs.push_back(0);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<AQDataType>> aq_tensors;
|
||||
std::vector<ck_tile::HostTensor<BQDataType>> bq_tensors;
|
||||
|
||||
a_m_k_tensors.reserve(group_count);
|
||||
b_k_n_tensors.reserve(group_count);
|
||||
c_m_n_tensors.reserve(group_count);
|
||||
aq_tensors.reserve(group_count);
|
||||
bq_tensors.reserve(group_count);
|
||||
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> aq_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> bq_dev_buf;
|
||||
|
||||
a_m_k_dev_buf.reserve(group_count);
|
||||
b_k_n_dev_buf.reserve(group_count);
|
||||
c_m_n_dev_buf.reserve(group_count);
|
||||
aq_dev_buf.reserve(group_count);
|
||||
bq_dev_buf.reserve(group_count);
|
||||
|
||||
std::vector<grouped_gemm_kargs> gemm_descs;
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
|
||||
const ck_tile::index_t M = Ms[i];
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
// For ABQuantGrouped, both A and B need quantization
|
||||
static_assert(QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
"This file only supports ABQuantGrouped mode");
|
||||
|
||||
AQK = K / AQuantGroupSize::kK; // Group quantization: AQK = K / AQuantGroupSize
|
||||
BQK = K / BQuantGroupSize::kK; // Group quantization: BQK = K / BQuantGroupSize
|
||||
if(K % AQuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode");
|
||||
}
|
||||
if(K % BQuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode");
|
||||
}
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
|
||||
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
|
||||
stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
|
||||
stride_BQs[i] = ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout));
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
|
||||
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
|
||||
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
|
||||
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
|
||||
<< " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl;
|
||||
|
||||
if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{1.f, 1.f}(bq_tensors[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
b_k_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
c_m_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
aq_dev_buf.push_back(
|
||||
std::make_unique<ck_tile::DeviceMem>(aq_tensors[i].get_element_space_size_in_bytes()));
|
||||
bq_dev_buf.push_back(
|
||||
std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
|
||||
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
c_m_n_tensors[i].SetZero();
|
||||
|
||||
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back({p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_aq,
|
||||
p_bq,
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
AQK,
|
||||
BQK,
|
||||
stride_As[i],
|
||||
stride_Bs[i],
|
||||
stride_Cs[i],
|
||||
stride_AQs[i],
|
||||
stride_BQs[i]});
|
||||
}
|
||||
|
||||
float ave_time = invoke_abquant_gemm<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(warmup, repeat, group_count, gemm_descs);
|
||||
|
||||
std::string op_name = "ABQuant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(int j = 0; j < group_count; ++j)
|
||||
{
|
||||
flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K;
|
||||
|
||||
num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K +
|
||||
sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N +
|
||||
sizeof(CDataType) * gemm_descs[j].M * gemm_descs[j].N;
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
|
||||
}
|
||||
|
||||
bool pass{true};
|
||||
if(validate)
|
||||
{
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
|
||||
Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Reference implementation for ABQuantGrouped
|
||||
ck_tile::reference_gemm_abquant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize>(
|
||||
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
Ks[i], kbatch, max_accumulated_value);
|
||||
pass &=
|
||||
ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results! in group [" + std::to_string(i) + "]",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "gemm[" << i
|
||||
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_grouped_gemm_json_results<ALayout, BLayout, CLayout>(arg_parser.get_str("jsonfile"),
|
||||
op_name,
|
||||
group_count,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename PrecType, typename GemmConfig, typename BQuantGroupSize>
|
||||
int run_abquant_grouped_gemm_example_prec_type_with_bquant(
|
||||
std::string a_layout, std::string b_layout, std::string c_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Types = GemmTypeConfig<PrecType>;
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using AccDataType = typename Types::AccDataType;
|
||||
using CDataType = typename Types::CDataType;
|
||||
using AQDataType = typename Types::AccDataType;
|
||||
using BQDataType = typename Types::AccDataType;
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C" && c_layout == "R")
|
||||
{
|
||||
return run_abquant_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R" && c_layout == "R")
|
||||
{
|
||||
return run_abquant_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R" && c_layout == "R")
|
||||
{
|
||||
return run_abquant_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
argc, argv, Col{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PrecType, typename GemmConfig>
|
||||
int run_abquant_grouped_gemm_example_prec_type(std::string a_layout,
|
||||
std::string b_layout,
|
||||
std::string c_layout,
|
||||
std::string bquant_group_size,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
if(bquant_group_size == "1x1x128")
|
||||
{
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_abquant_grouped_gemm_example_prec_type_with_bquant<PrecType,
|
||||
GemmConfig,
|
||||
BQuantGroupSize>(
|
||||
a_layout, b_layout, c_layout, argc, argv);
|
||||
}
|
||||
else if(bquant_group_size == "1x128x128")
|
||||
{
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return run_abquant_grouped_gemm_example_prec_type_with_bquant<PrecType,
|
||||
GemmConfig,
|
||||
BQuantGroupSize>(
|
||||
a_layout, b_layout, c_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported BQuantGroupSize! Use 1x1x128 or 1x128x128.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PrecType>
|
||||
int run_abquant_gemm_example_persistency(std::string a_layout,
|
||||
std::string b_layout,
|
||||
std::string c_layout,
|
||||
bool persistent,
|
||||
std::string bquant_group_size,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
if(persistent)
|
||||
{
|
||||
using GemmConfig = typename GemmQuantConfig<
|
||||
ck_tile::QuantType::ABQuantGrouped>::template GemmConfig<PrecType, true>;
|
||||
return run_abquant_grouped_gemm_example_prec_type<PrecType, GemmConfig>(
|
||||
a_layout, b_layout, c_layout, bquant_group_size, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
using GemmConfig = typename GemmQuantConfig<
|
||||
ck_tile::QuantType::ABQuantGrouped>::template GemmConfig<PrecType, false>;
|
||||
return run_abquant_grouped_gemm_example_prec_type<PrecType, GemmConfig>(
|
||||
a_layout, b_layout, c_layout, bquant_group_size, argc, argv);
|
||||
}
|
||||
}
|
||||
|
||||
int run_abquant_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string c_layout = arg_parser.get_str("c_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
bool persistent = arg_parser.get_bool("persistent");
|
||||
const std::string bquant_group_size = arg_parser.get_str("bquant_group_size");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
return run_abquant_gemm_example_persistency<ck_tile::fp8_t>(
|
||||
a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_abquant_gemm_example_persistency<ck_tile::bf8_t>(
|
||||
a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type configuration.");
|
||||
}
|
||||
}
|
||||
@@ -79,8 +79,7 @@ float invoke_gemm(int n_warmup,
|
||||
// earlier stage.
|
||||
|
||||
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = args[0].k_batch > 1;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr},
|
||||
@@ -109,7 +108,7 @@ float invoke_gemm(int n_warmup,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType>(stream, group_count, kargs_ptr, splitk);
|
||||
CDataType>(stream, group_count, kargs_ptr);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
|
||||
@@ -95,8 +95,7 @@ float invoke_gemm(int n_warmup,
|
||||
else
|
||||
{
|
||||
std::vector<ck_tile::GemmTransKernelArg<NumDTensor>> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = args[0].k_batch > 1;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, NumDTensor>{{arg.a_ptr},
|
||||
@@ -119,18 +118,17 @@ float invoke_gemm(int n_warmup,
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<NumDTensor>),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time =
|
||||
grouped_gemm_multi_d_tileloop<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise>(stream, group_count, kargs_ptr, splitk);
|
||||
ave_time = grouped_gemm_multi_d_tileloop<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise>(stream, group_count, kargs_ptr);
|
||||
}
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ if(has_supported_gpu)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
|
||||
|
||||
add_executable(tile_example_flatmm_basic flatmm_basic.cpp)
|
||||
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
@@ -30,13 +31,14 @@ if(has_supported_gpu)
|
||||
add_executable(tile_example_grouped_flatmm grouped_flatmm.cpp)
|
||||
target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
if (GPU_TARGETS MATCHES "gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx94")
|
||||
add_executable(tile_example_mixed_prec_flatmm mixed_prec/mixed_prec_flatmm.cpp)
|
||||
target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_a16w4_moe_flatmm mixed_prec/a16w4_moe_flatmm.cpp)
|
||||
target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
endif()
|
||||
if (GPU_TARGETS MATCHES "gfx95")
|
||||
include(mxgemm/mx_flatmm_instance.cmake)
|
||||
mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES)
|
||||
message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}")
|
||||
|
||||
@@ -170,13 +170,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -207,7 +204,6 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
@@ -282,23 +278,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
|
||||
@@ -113,13 +113,10 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config&
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -150,7 +147,6 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config&
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups>>;
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
@@ -216,23 +212,7 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config&
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
|
||||
@@ -113,13 +113,10 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
using CodegenPipelineProblem =
|
||||
std::conditional_t<MXFP4_Pipeline,
|
||||
@@ -159,7 +156,6 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
@@ -191,13 +187,15 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
std::cout << "Launching kernel " << Kernel::GetName() << "\n"
|
||||
<< "with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
<< "\n"
|
||||
<< "k_batch: " << kargs.k_batch << std::endl;
|
||||
}
|
||||
|
||||
if(s.flush_cache_)
|
||||
@@ -263,23 +261,7 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -471,10 +453,33 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[])
|
||||
throw std::runtime_error("Unsupported precision type for gemm2!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm1_split_k")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_split_k!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
|
||||
"[gemm1_gate_up | gemm2]");
|
||||
"[gemm1_gate_up | gemm1_split_k | gemm2]");
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
@@ -69,7 +69,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("gemm_kind",
|
||||
"gemm1_gate_up",
|
||||
"Gemm kind in FFN network [gemm1_gate_up | gemm2] - "
|
||||
"Gemm kind in FFN network [gemm1_gate_up | gemm2 | gemm1_split_k] - "
|
||||
"gemm1_gate_up by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
@@ -80,7 +80,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 16x16 (950 only, may use a larger tile than warp_tile=0)")
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.")
|
||||
.insert("k_batch", "1", "parallism to control splik-k.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -89,13 +89,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
@@ -128,7 +125,6 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
@@ -201,23 +197,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
|
||||
@@ -67,9 +67,12 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
return -1;
|
||||
};
|
||||
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using CDataType = PrecActType;
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using CDataType =
|
||||
std::conditional_t<kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k, float, PrecActType>;
|
||||
using AccDataType = float;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
@@ -88,6 +91,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
|
||||
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
|
||||
const ck_tile::index_t experts = arg_parser.get_int("experts");
|
||||
const ck_tile::index_t k_batch = arg_parser.get_int("k_batch");
|
||||
|
||||
// TODO: replace the magic declaration
|
||||
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
|
||||
@@ -231,14 +235,15 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
|
||||
|
||||
auto scale_b_shuffle_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()),
|
||||
N / ScaleGranularityN};
|
||||
auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<1>{
|
||||
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()), experts * N};
|
||||
|
||||
using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs<
|
||||
ck_tile::FlatmmScalePointer<-1>,
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>,
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>,
|
||||
ck_tile::FlatmmScalePointer<1>>;
|
||||
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
|
||||
p_sorted_expert_weight_dev,
|
||||
@@ -250,7 +255,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
num_tokens,
|
||||
experts,
|
||||
topk,
|
||||
1, // k_batch
|
||||
k_batch, // k_batch
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
|
||||
@@ -85,8 +85,9 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
c_rslt_host.SetZero();
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffle.data());
|
||||
|
||||
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK>{
|
||||
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
|
||||
auto scale_b_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
|
||||
|
||||
invoke_mixed_prec_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
|
||||
@@ -144,15 +144,11 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -184,7 +180,6 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
@@ -261,37 +256,20 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
args.NumTokens * args.TopK * outputN * sizeof(CDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
float ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
|
||||
@@ -25,14 +25,16 @@ using BF16 = ck_tile::bf16_t;
|
||||
using ROW = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using COL = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
inline constexpr auto ODD = ck_tile::TailNumber::Odd;
|
||||
inline constexpr auto EVEN = ck_tile::TailNumber::Even;
|
||||
|
||||
inline constexpr int ScaleGranularityM = 1;
|
||||
inline constexpr int ScaleGranularityN = 1;
|
||||
inline constexpr int ScaleGranularityK = 32;
|
||||
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>;
|
||||
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>;
|
||||
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>;
|
||||
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>;
|
||||
|
||||
template float mx_flatmm_calc<FLATMM_CONFIG,
|
||||
A_DATA_TYPE,
|
||||
|
||||
@@ -61,8 +61,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
|
||||
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation =
|
||||
Splitk ? ck_tile::memory_operation_enum::atomic_add : ck_tile::memory_operation_enum::set;
|
||||
ck_tile::ignore = Splitk;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
@@ -98,7 +97,6 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
MXPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
|
||||
@@ -105,10 +105,12 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
|
||||
|
||||
auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
|
||||
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
auto scale_a_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
|
||||
auto scale_b_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
|
||||
invoke_mx_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
|
||||
@@ -81,60 +81,45 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
|
||||
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y
|
||||
<< ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
else
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
|
||||
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
#include "run_gemm_multi_d_fp16_example.inc"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
|
||||
set(EXAMPLE_CONV_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
|
||||
@@ -59,94 +59,80 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
ConvConfig::NumWaveGroups>;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
InDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
InDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
ck_tile::hip_check_error(hipMemsetAsync(
|
||||
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
|
||||
};
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
auto preprocess = [&]() {
|
||||
ck_tile::hip_check_error(hipMemsetAsync(
|
||||
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
}
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -59,104 +59,85 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
ConvConfig::NumWaveGroups>;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WeiDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WeiDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
const auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
ck_tile::hip_check_error(hipMemsetAsync(
|
||||
kargs.wei_ptr, 0, args.template GetWeightByte<WeiDataType>(), s.stream_id_));
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(kargs.wei_ptr,
|
||||
0,
|
||||
args.template GetWeightByte<WeiDataType>(),
|
||||
s.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
const auto ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
const auto split_k = kargs.k_batch;
|
||||
|
||||
return InvokerResult{ave_time, split_k};
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
}
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return InvokerResult{ave_time, args.k_batch};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -21,6 +21,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using WorkspaceDataType = float;
|
||||
// Force Vector Size C to 1 for two stage to check main
|
||||
// two stage use case
|
||||
constexpr ck_tile::index_t VectorSizeC = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
@@ -39,7 +42,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
OutLayout,
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC,
|
||||
VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
@@ -62,163 +65,143 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType, // A: Out
|
||||
InDataType, // B: In
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WorkspaceDataType, // C: Workspace normally Out
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType, // A: Out
|
||||
InDataType, // B: In
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WorkspaceDataType, // C: Workspace normally Out
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
const ck_tile::index_t spatial_lengths_accum =
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum *
|
||||
sizeof(WorkspaceDataType));
|
||||
ck_tile::GroupedConvBwdWeightHostArgs ws_args = ck_tile::GroupedConvBwdWeightHostArgs(args);
|
||||
auto c_ptr = ws_args.wei_ptr;
|
||||
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
|
||||
const ck_tile::index_t spatial_lengths_accum =
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum *
|
||||
sizeof(WorkspaceDataType));
|
||||
ck_tile::GroupedConvBwdWeightHostArgs ws_args =
|
||||
ck_tile::GroupedConvBwdWeightHostArgs(args);
|
||||
auto c_ptr = ws_args.wei_ptr;
|
||||
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
const auto kargs = Kernel::MakeKernelArgs(ws_args);
|
||||
const auto kargs = Kernel::MakeKernelArgs(ws_args);
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using BlockTile = ck_tile::sequence<2048>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using BlockTile = ck_tile::sequence<2048>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceDataType,
|
||||
WorkspaceDataType,
|
||||
WeiDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
using ElementwiseKernel =
|
||||
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceDataType,
|
||||
WorkspaceDataType,
|
||||
WeiDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
using ElementwiseKernel =
|
||||
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {
|
||||
static_cast<ck_tile::index_t>(args.G_ * args.K_),
|
||||
static_cast<ck_tile::index_t>(args.C_ * spatial_lengths_accum)};
|
||||
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {
|
||||
static_cast<ck_tile::index_t>(args.G_ * args.K_),
|
||||
static_cast<ck_tile::index_t>(args.C_ * spatial_lengths_accum)};
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
|
||||
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize =
|
||||
(total_elements + elements_per_block - 1) / elements_per_block;
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
|
||||
auto input_size = ck_tile::make_tuple(shape[0], shape[1]);
|
||||
|
||||
auto input_tensors =
|
||||
ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
|
||||
auto input_size = ck_tile::make_tuple(shape[0], shape[1]);
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseKernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseKernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(kargs.k_batch > 1)
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(ws_args.wei_ptr,
|
||||
0,
|
||||
shape[0] * shape[1] * sizeof(WorkspaceDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
|
||||
const auto ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
|
||||
ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(shape[1], 1), // Input Stride
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
|
||||
const auto split_k = kargs.k_batch;
|
||||
|
||||
return InvokerResult{ave_time, split_k};
|
||||
auto preprocess = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(ws_args.wei_ptr,
|
||||
0,
|
||||
shape[0] * shape[1] * sizeof(WorkspaceDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
}
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
|
||||
ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(shape[1], 1), // Input Stride
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
return InvokerResult{ave_time, kargs.k_batch};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -70,91 +70,74 @@ struct GroupedConvolutionForwardInvoker
|
||||
// =====================================================================
|
||||
// Regular Convolution: Simple, no split-image
|
||||
// =====================================================================
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
// =====================================================================
|
||||
// Split-K dispatch
|
||||
// =====================================================================
|
||||
if(args.k_batch == 1)
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
else
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -213,8 +213,7 @@ struct GroupedConvolutionForwardInvoker
|
||||
// =====================================================================
|
||||
// Kernel launch lambda: Uses EnableSplitImage based on layout support
|
||||
// =====================================================================
|
||||
const auto Run = [&](const auto memory_operation_, const auto enable_split_image_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto enable_split_image_) {
|
||||
constexpr bool EnableSplitImage = enable_split_image_.value;
|
||||
|
||||
using GroupedConvTraitsType = std::conditional_t<EnableSplitImage,
|
||||
@@ -255,7 +254,6 @@ struct GroupedConvolutionForwardInvoker
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
@@ -332,17 +330,11 @@ struct GroupedConvolutionForwardInvoker
|
||||
// =====================================================================
|
||||
if(use_split_image)
|
||||
{
|
||||
if(args.k_batch == 1)
|
||||
return Run(MemoryOpSet{}, ck_tile::bool_constant<true>{});
|
||||
else
|
||||
return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant<true>{});
|
||||
return Run(ck_tile::bool_constant<true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
if(args.k_batch == 1)
|
||||
return Run(MemoryOpSet{}, ck_tile::bool_constant<false>{});
|
||||
else
|
||||
return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant<false>{});
|
||||
return Run(ck_tile::bool_constant<false>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -13,11 +13,6 @@
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "conv_configs.hpp"
|
||||
|
||||
using MemoryOpSet =
|
||||
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
|
||||
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>;
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
|
||||
const ck_tile::index_t kbatch,
|
||||
|
||||
@@ -85,60 +85,44 @@ auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_conf
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
|
||||
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y
|
||||
<< ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
else
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
|
||||
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
#include "run_gemm_multi_abd_fp16_example.inc"
|
||||
|
||||
@@ -12,6 +12,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
set(EXE_NAME tile_example_gemm_quant)
|
||||
add_executable(${EXE_NAME}
|
||||
gemm_quant.cpp
|
||||
gemm_abquant_quantgrouped.cpp
|
||||
gemm_aquant_quantgrouped.cpp
|
||||
gemm_aquant_quantgrouped_preshufflequant.cpp
|
||||
gemm_bquant_quantgrouped_bf8i4.cpp
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
void abquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -9,36 +9,194 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
@@ -47,10 +205,63 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
|
||||
@@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("prec",
|
||||
"fp8",
|
||||
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
|
||||
"bf8i4 or bf16fp4")
|
||||
"or bf8i4; for ABQuant: fp8, bf8")
|
||||
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
@@ -41,7 +41,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "Flush cache before running the kernel")
|
||||
.insert("rotating_count", "1000", "Rotating count")
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant, abquant, tensor or rowcol")
|
||||
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
|
||||
.insert("preshufflequant", "false", "Enable preshuffle of quant tensor")
|
||||
.insert("group_size",
|
||||
@@ -75,6 +75,16 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
|
||||
arg_parser.get_bool("preshufflequant") ? "preshufflequant" : "non-preshufflequant";
|
||||
params.push_back(preshufflequant);
|
||||
}
|
||||
if(quant_mode == "abquant")
|
||||
{
|
||||
std::string preshuffleb =
|
||||
arg_parser.get_bool("preshuffleb") ? "preshuffleb" : "non-preshuffleb";
|
||||
params.push_back(preshuffleb);
|
||||
|
||||
std::string preshufflequant =
|
||||
arg_parser.get_bool("preshufflequant") ? "preshufflequant" : "non-preshufflequant";
|
||||
params.push_back(preshufflequant);
|
||||
}
|
||||
if(quant_mode != "rowcol" && quant_mode != "tensor")
|
||||
{
|
||||
// NOTE: rowcol and tensor pipeline do not use group size
|
||||
@@ -85,6 +95,8 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
|
||||
return hash_multiple_strings(params);
|
||||
}
|
||||
|
||||
void abquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void aquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
@@ -124,6 +136,7 @@ int main(int argc, char* argv[])
|
||||
ck_tile::hip_check_error(hipSetDevice(device_id));
|
||||
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
|
||||
abquant_quantgrouped_instance_factory(lut);
|
||||
aquant_quantgrouped_instance_factory(lut);
|
||||
aquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8_instance_factory(lut);
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
@@ -25,7 +26,8 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename QuantGroupSize,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename CDEElementWise>
|
||||
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
@@ -72,9 +74,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>>;
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>;
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
@@ -87,7 +90,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
// row-col and tensor quants use the regular pipeline, A/B quants use their own
|
||||
// row-col and tensor quants use the regular pipeline, A/B/AB quants use their own
|
||||
using PipelineProblem = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
@@ -102,54 +105,81 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::QDataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
QuantGroupSize,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::GemmBQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::QDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
QuantGroupSize,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>;
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::QDataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
AQuantGroupSize,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::GemmBQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::QDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
BQuantGroupSize,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::GemmABQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::QDataType, // For AQ
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::QDataType, // For BQ
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>>;
|
||||
using AQuantPipeline =
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>;
|
||||
|
||||
using BQuantPipeline = std::conditional_t<
|
||||
GemmConfig::PreshuffleB,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
|
||||
|
||||
using ABQuantPipeline =
|
||||
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
|
||||
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
|
||||
std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>>;
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
AQuantPipeline,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
ABQuantPipeline,
|
||||
BQuantPipeline>>>;
|
||||
|
||||
constexpr bool TiledPermuteN =
|
||||
(QuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
|
||||
(BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
printf(
|
||||
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, QuantGroupSize::kN);
|
||||
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
|
||||
}
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
typename TypeConfig::ADataType,
|
||||
@@ -171,7 +201,6 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
@@ -264,7 +293,8 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename QuantGroupSize,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
@@ -277,6 +307,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t AQK,
|
||||
ck_tile::index_t BQK,
|
||||
ck_tile::index_t BQN,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_AQ,
|
||||
ck_tile::index_t stride_B,
|
||||
@@ -313,7 +344,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
QuantGroupSize,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode,
|
||||
CDEElementWise>(
|
||||
args,
|
||||
@@ -330,7 +362,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
}
|
||||
if(bq_dev_buf != nullptr)
|
||||
{
|
||||
num_byte += sizeof(typename TypeConfig::QDataType) * N * BQK;
|
||||
num_byte += sizeof(typename TypeConfig::QDataType) * BQN * BQK;
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
@@ -338,10 +370,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
|
||||
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
|
||||
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
|
||||
<< " AQ_Layout =" << AQLayout::name << " BQ_Layout =" << BQLayout::name;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
<< " StrideBQ =" << stride_BQ << " StrideC =" << stride_C
|
||||
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout =" << CLayout::name << " AQ_Layout =" << AQLayout::name
|
||||
<< " BQ_Layout =" << BQLayout::name;
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
std::cout << " StrideBQ =" << stride_BQ;
|
||||
@@ -366,7 +401,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
typename QuantGroupSize,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
@@ -395,21 +431,30 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
AQK = ck_tile::integer_divide_ceil(
|
||||
K, QuantGroupSize::kK); // Group quantization: AQK = K / GroupSize
|
||||
BQK = 0; // No B quantization
|
||||
K, AQuantGroupSize::kK); // Group quantization: AQK = K / GroupSize
|
||||
BQK = 0; // No B quantization
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
AQK = 0; // No A quantization
|
||||
BQK = ck_tile::integer_divide_ceil(
|
||||
K, QuantGroupSize::kK); // Group quantization: BQK = K / GroupSize
|
||||
BQN = ck_tile::integer_divide_ceil(N, QuantGroupSize::kN);
|
||||
K, BQuantGroupSize::kK); // Group quantization: BQK = K / GroupSize
|
||||
BQN = ck_tile::integer_divide_ceil(N, BQuantGroupSize::kN);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
AQK = ck_tile::integer_divide_ceil(
|
||||
K, AQuantGroupSize::kK); // Group quantization: AQK = K / GroupSize
|
||||
BQK = ck_tile::integer_divide_ceil(
|
||||
K, BQuantGroupSize::kK); // Group quantization: BQK = K / GroupSize
|
||||
BQN = ck_tile::integer_divide_ceil(N, BQuantGroupSize::kN);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
|
||||
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
|
||||
BQN = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -419,9 +464,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t stride_BQ = arg_parser.get_int("stride_q");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
@@ -449,6 +493,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
stride_AQ = 0; // No A quantization
|
||||
stride_BQ = ck_tile::get_default_stride(BQK, BQN, stride_BQ, is_row_major(bq_layout));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout));
|
||||
stride_BQ = ck_tile::get_default_stride(BQK, BQN, stride_BQ, is_row_major(bq_layout));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
stride_AQ = ck_tile::get_default_stride(M, 1, stride_AQ, is_row_major(aq_layout));
|
||||
@@ -473,6 +522,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
// Create AQ tensor with appropriate shape
|
||||
std::unique_ptr<ck_tile::HostTensor<AQDataType>> aq_tensor_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
|
||||
@@ -492,6 +542,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
@@ -543,6 +598,25 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
||||
@@ -566,6 +640,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
|
||||
@@ -591,6 +672,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> aq_dev_buf_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
@@ -599,6 +681,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
std::unique_ptr<ck_tile::DeviceMem> bq_dev_buf_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
@@ -607,13 +690,14 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
|
||||
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / AQuantGroupSize::kK);
|
||||
aq_dev_buf_ptr->ToDevice(aq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
@@ -637,7 +721,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
if constexpr(GemmConfig::PreshuffleB)
|
||||
{
|
||||
if constexpr(GemmConfig::TiledMMAPermuteN && QuantGroupSize::kN == 1)
|
||||
if constexpr(GemmConfig::TiledMMAPermuteN && BQuantGroupSize::kN == 1)
|
||||
{
|
||||
printf("PreshuffleB with TiledMMAPermuteN\n");
|
||||
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
|
||||
@@ -659,19 +743,20 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN &&
|
||||
QuantGroupSize::kN == 1)
|
||||
BQuantGroupSize::kN == 1)
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_permuted_host =
|
||||
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, QuantGroupSize::kN);
|
||||
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, BQuantGroupSize::kN);
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(&bq_permuted_host, GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host = ck_tile::shuffle_bq(
|
||||
&bq_permuted_host, GemmConfig::K_Tile / BQuantGroupSize::kK);
|
||||
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
@@ -682,7 +767,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / BQuantGroupSize::kK);
|
||||
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
@@ -698,7 +783,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
QuantGroupSize,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(a_m_k_dev_buf,
|
||||
aq_dev_buf_ptr.get(),
|
||||
b_k_n_dev_buf,
|
||||
@@ -709,6 +795,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
K,
|
||||
AQK,
|
||||
BQK,
|
||||
BQN,
|
||||
stride_A,
|
||||
stride_AQ,
|
||||
stride_B,
|
||||
@@ -736,7 +823,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
AQuantGroupSize,
|
||||
true>(a_m_k, *aq_tensor_ptr, b_k_n, c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
@@ -747,7 +834,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
false>(
|
||||
a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
|
||||
else
|
||||
@@ -756,9 +843,21 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
false>(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_abquant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize>(
|
||||
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
ck_tile::reference_gemm_rowcol_quant<ADataType,
|
||||
@@ -806,10 +905,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
// Usage of Two-Matrix Quantization (AB-Quant)
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
typename QuantGroupSize,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
@@ -835,17 +935,24 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize, QuantMode>(
|
||||
return run_gemm_example_with_layouts<GemmConfig,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant)
|
||||
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
|
||||
!GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
arg_parser, Row{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
@@ -853,24 +960,24 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
arg_parser, Col{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
}
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
arg_parser, Col{}, Col{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
@@ -883,3 +990,16 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
return 0;
|
||||
}
|
||||
// Support for Unilateral Quantization (A/B)
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
QuantGroupSize,
|
||||
QuantMode>(arg_parser);
|
||||
}
|
||||
|
||||
@@ -48,112 +48,87 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS,
|
||||
GemmConfiguration::PRESHUFFLE>;
|
||||
|
||||
const auto runKernel = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccumulatorDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfiguration::SCHEDULER>;
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccumulatorDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfiguration::SCHEDULER>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfiguration::M_WARP,
|
||||
GemmConfiguration::N_WARP,
|
||||
GemmConfiguration::M_WARP_TILE,
|
||||
GemmConfiguration::N_WARP_TILE,
|
||||
GemmConfiguration::K_WARP_TILE,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation.value,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfiguration::M_WARP,
|
||||
GemmConfiguration::N_WARP,
|
||||
GemmConfiguration::M_WARP_TILE,
|
||||
GemmConfiguration::N_WARP_TILE,
|
||||
GemmConfiguration::K_WARP_TILE,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS>>;
|
||||
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kernel_args = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
auto kernel_args = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
workspace_data.SetZero();
|
||||
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
|
||||
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
{
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
}
|
||||
|
||||
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
auto reset_data_buffers = [&]() {
|
||||
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
}
|
||||
|
||||
auto reset_data_buffers = [&]() {
|
||||
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
}
|
||||
};
|
||||
|
||||
std::function<void()> preprocess = reset_data_buffers;
|
||||
|
||||
float average_time =
|
||||
ck_tile::launch_kernel_time_mask(stream_config,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
|
||||
Kernel{}, grids, blocks, 0, kernel_args));
|
||||
|
||||
ck_tile::index_t num_wgs_per_tile =
|
||||
kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
return std::tuple{average_time, num_wgs_per_tile};
|
||||
};
|
||||
|
||||
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy)
|
||||
{
|
||||
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// Since we are doing stream K, in the case of
|
||||
// atomics, multiple workgroups may write to the
|
||||
// same output tile in the C tensor, so we must
|
||||
// atomic add the results (not set)
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else // We are using ck_tile::StreamKReductionStrategy::Reduction
|
||||
{
|
||||
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// In this case, there is only ever 1 WG writing
|
||||
// final results to each macro tile in the C
|
||||
// tensor, so we can do a set.
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
std::function<void()> preprocess = reset_data_buffers;
|
||||
|
||||
float average_time =
|
||||
ck_tile::launch_kernel_time_mask(stream_config,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
|
||||
Kernel{}, grids, blocks, 0, kernel_args));
|
||||
|
||||
ck_tile::index_t num_wgs_per_tile = kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
return std::tuple{average_time, num_wgs_per_tile};
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
@@ -92,67 +92,59 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataT
|
||||
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
|
||||
const auto Run = [&]() {
|
||||
constexpr auto memory_operation =
|
||||
ck_tile::memory_operation_enum::set; // Always set (no atomic_add)
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel =
|
||||
ck_tile::BatchedContractionKernel<Problem, TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
using Kernel =
|
||||
ck_tile::BatchedContractionKernel<Problem, TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::GetBlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::GetBlockSize();
|
||||
if(!Kernel::IsSupportedArguments(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArguments(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
auto kernel = ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs);
|
||||
|
||||
auto kernel = ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs);
|
||||
|
||||
return ck_tile::launch_kernel(s, kernel);
|
||||
};
|
||||
|
||||
return Run();
|
||||
return ck_tile::launch_kernel(s, kernel);
|
||||
}
|
||||
|
||||
#define HANDLE_CASE(G, M, N, K) \
|
||||
|
||||
Reference in New Issue
Block a user