mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
fix and merge
This commit is contained in:
@@ -25,7 +25,7 @@ execute_process(
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
|
||||
message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
@@ -34,7 +34,7 @@ execute_process(
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.")
|
||||
message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.")
|
||||
endif()
|
||||
|
||||
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
|
||||
@@ -57,7 +57,7 @@ add_custom_command(
|
||||
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_FMHA_FWD}")
|
||||
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
|
||||
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
|
||||
@@ -65,7 +65,7 @@ target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
|
||||
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_FMHA_BWD}")
|
||||
message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
|
||||
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
|
||||
|
||||
@@ -71,6 +71,7 @@ args:
|
||||
-drop_seed seed for random number generator (default:1)
|
||||
-drop_offset offset for random number generator (default:0)
|
||||
-drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0)
|
||||
-num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
```
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
@@ -117,8 +117,50 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s){{
|
||||
#include <cstdio>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cu) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cu = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{
|
||||
float r = -1;
|
||||
|
||||
const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
@@ -134,36 +176,50 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return 'true'
|
||||
else:
|
||||
return f'{self.bool_expr}'
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f'({str(self)}) && ({str(other)})')
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag : str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
constraint : CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -220,17 +276,18 @@ class FmhaFwdApiTrait:
|
||||
class FmhaFwdPipeline:
|
||||
tag : str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -297,8 +354,8 @@ class FmhaFwdApiPool:
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
@@ -313,25 +370,27 @@ class FmhaFwdApiPool:
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
|
||||
@@ -423,33 +482,21 @@ class FmhaFwdKernel:
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
### '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
### '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
class KernelComponentFactory:
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'128' : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
@@ -458,54 +505,41 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'128' : [FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),]
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
d = CustomFactory.get_hdim_tile_size_dict(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
tiles = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
|
||||
@@ -58,7 +58,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy}>;
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
@@ -94,7 +95,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -129,9 +130,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return fmha_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -160,11 +161,12 @@ class FmhaFwdApiTrait:
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
skip : str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
@@ -227,6 +229,7 @@ class FmhaFwdPipeline:
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_skip : str # true/false
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -262,8 +265,12 @@ class FmhaFwdPipeline:
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
else: n += '_ndropout'
|
||||
|
||||
if self.F_skip == 't' : n += '_skip'
|
||||
else: n += '_nskip'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
|
||||
return n
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
@@ -275,31 +282,32 @@ class FmhaFwdApiPool:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
hdim = trait.hdim, trait.bn1
|
||||
if hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
self.pool[trait.dtype][hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][(hdim, hdim_v)]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
@@ -381,6 +389,7 @@ class FmhaFwdKernel:
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
@@ -419,25 +428,28 @@ class FmhaFwdKernel:
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(32, 32) : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(64, 64) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (96, 128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (160,160) : FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
(192,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (192,192) : FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
(256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(64,64 ) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -445,7 +457,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
def get_pipelines(dtype, hdim, hdim_v) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
@@ -453,36 +465,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
@@ -498,17 +510,15 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
for ((hdim, hdim_v), tile), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
for pipeline in get_pipelines(dtype, hdim, hdim_v):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
if (hdim, hdim_v) == (192, 128) or hdim == 160:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
|
||||
if pipeline.F_bias != 'no' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
@@ -532,6 +542,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
@@ -540,6 +551,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
@@ -565,6 +577,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ K0_MAX_SUBMAX_MAP = {
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
# 160: 160,
|
||||
256: 256
|
||||
}
|
||||
|
||||
@@ -638,6 +639,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -656,6 +658,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -683,7 +686,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
# TODO: use async pipeline when compiler is more stable
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128, 160]:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_bwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
@@ -756,22 +756,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
p_hp_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); });
|
||||
p_dropped_hp_host_ref = p_hp_host_ref;
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
else
|
||||
{
|
||||
p_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
|
||||
// O = P * V
|
||||
@@ -854,29 +849,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
auto idx_gmo = idx_gmn;
|
||||
idx_gmo[2] = o;
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(idx_gmo)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](idx_gmo));
|
||||
}
|
||||
self(idx_gmn) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o));
|
||||
});
|
||||
ck_tile::make_ParallelTensorFunctor(
|
||||
[&](auto i0, auto i1, auto i2) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](i0, i1, o));
|
||||
}
|
||||
ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o));
|
||||
},
|
||||
ds_hp_host_ref.mDesc.get_lengths()[0],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[1],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
|
||||
|
||||
if(use_dbias)
|
||||
{
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
dbias_host_ref(idx) = ck_tile::type_convert<BiasGradDataType>(self(idx));
|
||||
});
|
||||
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
|
||||
}
|
||||
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
ds_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
ds_lp_host_ref = ds_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
|
||||
// dV = P_drop^T@dO^T
|
||||
// dV = P^T@dO^T w/o dropout
|
||||
|
||||
47
example/ck_tile/01_fmha/fmha_fwd.cpp
Normal file → Executable file
47
example/ck_tile/01_fmha/fmha_fwd.cpp
Normal file → Executable file
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
@@ -178,50 +178,30 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
|
||||
}
|
||||
}
|
||||
|
||||
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
|
||||
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
|
||||
{
|
||||
// If we have enough to almost fill the SMs, then just use 1 split
|
||||
if(batch_nhead_mblocks >= 0.8f * num_SMs)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
|
||||
max_splits = std::min({max_splits, num_SMs});
|
||||
float max_efficiency = 0.f;
|
||||
std::vector<float> efficiency;
|
||||
efficiency.reserve(max_splits);
|
||||
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
|
||||
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
|
||||
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
|
||||
// (i.e. it's 11 splits anyway).
|
||||
// So we check if the number of blocks per split is the same as the previous num_splits.
|
||||
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
|
||||
return num_splits == 1 ||
|
||||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
|
||||
};
|
||||
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
||||
{
|
||||
if(!is_split_eligible(num_splits))
|
||||
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
|
||||
float eff = n_waves / ceil(n_waves);
|
||||
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
||||
if(eff > max_efficiency)
|
||||
{
|
||||
efficiency.push_back(0.f);
|
||||
}
|
||||
else
|
||||
{
|
||||
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
|
||||
float eff = n_waves / ceil(n_waves);
|
||||
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
||||
if(eff > max_efficiency)
|
||||
{
|
||||
max_efficiency = eff;
|
||||
}
|
||||
efficiency.push_back(eff);
|
||||
max_efficiency = eff;
|
||||
}
|
||||
efficiency.push_back(eff);
|
||||
}
|
||||
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
||||
{
|
||||
if(!is_split_eligible(num_splits))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
|
||||
{
|
||||
// printf("num_splits chosen = %d\n", num_splits);
|
||||
@@ -234,6 +214,7 @@ int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks,
|
||||
int override_num_splits_if_necessary(
|
||||
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
|
||||
{
|
||||
(void)hdim_v;
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess)
|
||||
@@ -250,15 +231,13 @@ int override_num_splits_if_necessary(
|
||||
|
||||
// tile size should match the generate.py
|
||||
const int kM0 = 64;
|
||||
const int kN1 = hdim_v;
|
||||
|
||||
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
|
||||
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
if(num_splits < 1 && p_drop == 0.0f)
|
||||
{
|
||||
return num_splits_heuristic(
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128);
|
||||
}
|
||||
|
||||
return num_splits;
|
||||
@@ -542,8 +521,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
max_seqlen_k = real_seqlen_k;
|
||||
}
|
||||
|
||||
flop += nhead * (static_cast<std::size_t>(2) * real_seqlen_q * real_seqlen_k * hdim_q +
|
||||
static_cast<std::size_t>(2) * real_seqlen_q * hdim_v * real_seqlen_k);
|
||||
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
|
||||
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
|
||||
|
||||
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
|
||||
sizeof(KDataType) * real_seqlen_k * hdim_q +
|
||||
|
||||
@@ -169,6 +169,7 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
ck_tile::index_t min_seqlen_q;
|
||||
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
@@ -433,6 +434,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
@@ -713,102 +715,102 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -837,7 +839,8 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
bool kPadDv_,
|
||||
bool kSkipMinSeqlenQ_ = false>
|
||||
struct fmha_fwd_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -861,6 +864,7 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
@@ -995,6 +999,7 @@ struct fmha_fwd_traits
|
||||
bool has_lse;
|
||||
bool has_dropout;
|
||||
bool do_fp8_static_quant;
|
||||
bool skip_min_seqlen_q = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
21
example/ck_tile/01_fmha/mask.hpp
Normal file → Executable file
21
example/ck_tile/01_fmha/mask.hpp
Normal file → Executable file
@@ -21,6 +21,8 @@ enum class mask_enum
|
||||
struct mask_info
|
||||
{
|
||||
mask_enum type;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t y, x;
|
||||
ck_tile::index_t left, right; // FA style SWA left/right
|
||||
|
||||
@@ -42,6 +44,8 @@ struct mask_info
|
||||
ck_tile::index_t x_total = seqlen_k;
|
||||
ck_tile::index_t y_total = seqlen_q;
|
||||
mask_info tmp;
|
||||
tmp.seqlen_q = seqlen_q;
|
||||
tmp.seqlen_k = seqlen_k;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
@@ -148,7 +152,22 @@ struct mask_info
|
||||
}
|
||||
return tmp;
|
||||
}
|
||||
|
||||
ck_tile::index_t get_unmaskarea() const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
return seqlen_q * seqlen_k;
|
||||
ck_tile::index_t area = 0;
|
||||
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
|
||||
{
|
||||
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
|
||||
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
|
||||
if(x_end > x_start)
|
||||
{
|
||||
area += (x_end - x_start);
|
||||
}
|
||||
}
|
||||
return area;
|
||||
}
|
||||
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
|
||||
{
|
||||
mi.serialize(os);
|
||||
|
||||
@@ -25,7 +25,7 @@ add_custom_command(
|
||||
|
||||
set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
|
||||
|
||||
message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
|
||||
message(DEBUG "adding example ${EXAMPLE_LAYERNORM2D_FWD}")
|
||||
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
|
||||
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
|
||||
|
||||
@@ -75,22 +75,22 @@ struct layernorm2d_fwd_traits_
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -98,13 +98,13 @@ struct layernorm2d_fwd_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ args:
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-e Absolute error tolerance (default:1e-5)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
|
||||
@@ -12,15 +12,23 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
if constexpr(Persistent)
|
||||
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
@@ -50,8 +58,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
@@ -60,9 +70,12 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -128,12 +141,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -144,24 +157,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
}
|
||||
else
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -199,19 +212,39 @@ int run_gemm_example(int argc, char* argv[])
|
||||
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
else if(data_type == "i8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::int8_t, ck_tile::int8_t, int32_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_gemm_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
@@ -14,99 +13,28 @@
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_ASYNC 4
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_ASYNC
|
||||
#endif
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_ASYNC)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompAsync
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompAsync
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
#else
|
||||
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
|
||||
#endif
|
||||
|
||||
struct GemmConfig
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
#endif
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_ASYNC)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
#if defined(__gfx950__)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
@@ -120,6 +48,169 @@ struct GemmConfig
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
{
|
||||
// Compute V3 only support Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
{
|
||||
// Compute V4 only support Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 2;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
@@ -171,6 +262,15 @@ struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
|
||||
{
|
||||
using ADataType = ck_tile::int8_t;
|
||||
using BDataType = ck_tile::int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CDataType = int32_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
@@ -186,6 +286,12 @@ struct DataTypeTraits<double>
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
@@ -216,6 +322,51 @@ struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -234,11 +385,23 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("persistent", "0", "0:non-persistent, 1:persistent");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent = false,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -30,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename Tensor,
|
||||
template <typename GemmConfig,
|
||||
typename Tensor,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -63,11 +64,12 @@ void permute_tensor_b(Tensor& tensor)
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GEMM_PIPELINE_SCHEDULER,
|
||||
GemmConfig::Scheduler,
|
||||
true,
|
||||
ck_tile::TailNumber::Full>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
@@ -144,13 +146,31 @@ void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s);
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
@@ -162,23 +182,55 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
int n_repeat,
|
||||
bool persistent)
|
||||
{
|
||||
ck_tile::GemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C};
|
||||
|
||||
float ave_time =
|
||||
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
float ave_time;
|
||||
if(persistent)
|
||||
{
|
||||
ave_time = gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
true,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
false,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
@@ -193,13 +245,14 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
<< " B_Type=" << DataTypeTraits<BDataType>::name
|
||||
<< " C_Type=" << DataTypeTraits<CDataType>::name
|
||||
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
|
||||
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType = ADataType,
|
||||
typename CDataType = ADataType,
|
||||
typename ALayout,
|
||||
@@ -229,6 +282,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool persistent = arg_parser.get_int("persistent");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
@@ -243,8 +297,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
@@ -278,7 +332,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
permute_tensor_b<decltype(b_k_n_dev),
|
||||
permute_tensor_b<GemmConfig,
|
||||
decltype(b_k_n_dev),
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -304,19 +359,28 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
persistent);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
@@ -351,29 +415,19 @@ int run_gemm_example_with_layouts(int argc,
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
// memory on host to store gpu reference result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
// memory on device to store gpu reference result
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
|
||||
|
||||
c_m_n_gpu_ref.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A;
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes()));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes()));
|
||||
ck_tile::hip_check_error(
|
||||
hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes()));
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_A,
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
a_m_k.get_element_space_size_in_bytes(),
|
||||
hipMemcpyHostToDevice));
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n.get_element_space_size_in_bytes(),
|
||||
hipMemcpyHostToDevice));
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
@@ -383,16 +437,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
|
||||
d_C,
|
||||
c_m_n_dev_result.get_element_space_size_in_bytes(),
|
||||
hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_A));
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
|
||||
@@ -11,28 +11,22 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename Pipeline, ck_tile::TailNumber TN>
|
||||
void try_run(ck_tile::TailNumber tn)
|
||||
{
|
||||
if constexpr(Pipeline::PrefetchStages > static_cast<int>(TN))
|
||||
{
|
||||
if(tn == TN)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, TN>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
@@ -41,30 +35,36 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
ELayout,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity>;
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
@@ -74,64 +74,118 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
@@ -150,103 +204,14 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
}
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "For compute pipeline tail number should always be Full, but have \"" << tail_num
|
||||
<< "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
auto check_tail = [&](auto... TNs) {
|
||||
(try_run<BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
|
||||
};
|
||||
|
||||
check_tail(ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4 || \
|
||||
CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_ASYNC)
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
template <typename GemmConfig,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
@@ -256,14 +221,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
// else if(a_layout == "C" && b_layout == "C")
|
||||
// {
|
||||
// return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
// argc, argv, Col{}, Col{}, Row{});
|
||||
// }
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
@@ -272,26 +237,26 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
}
|
||||
else
|
||||
{
|
||||
// if(a_layout == "R" && b_layout == "R")
|
||||
// {
|
||||
// return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
// argc, argv, Row{}, Row{}, Row{});
|
||||
// }
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
// else if(a_layout == "C" && b_layout == "R")
|
||||
// {
|
||||
// return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
// argc, argv, Col{}, Row{}, Row{});
|
||||
// }
|
||||
// else if(a_layout == "C" && b_layout == "C")
|
||||
// {
|
||||
// return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
// argc, argv, Col{}, Col{}, Row{});
|
||||
// }
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
@@ -299,7 +264,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
}
|
||||
}
|
||||
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
@@ -311,31 +276,50 @@ int run_gemm_example(int argc, char* argv[])
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "int8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
@@ -346,7 +330,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
run_gemm_example(argc, argv);
|
||||
return !run_gemm_example<GemmConfigComputeV3>(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
set(EXAMPLE_REDUCE "tile_example_reduce")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_REDUCE}")
|
||||
message(DEBUG "adding example ${EXAMPLE_REDUCE}")
|
||||
|
||||
add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL reduce.cpp)
|
||||
target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
@@ -35,7 +35,7 @@ struct Reduce2dShape
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
};
|
||||
|
||||
template <typename XDataType_,
|
||||
|
||||
@@ -25,7 +25,7 @@ add_custom_command(
|
||||
|
||||
set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd")
|
||||
|
||||
message("adding ${TILE_RMSNORM2D_FWD}")
|
||||
message(DEBUG "adding ${TILE_RMSNORM2D_FWD}")
|
||||
add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp)
|
||||
target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
|
||||
|
||||
@@ -74,22 +74,22 @@ struct rmsnorm2d_fwd_traits_
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -97,13 +97,13 @@ struct rmsnorm2d_fwd_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -712,4 +712,4 @@ if __name__ == "__main__":
|
||||
if args.list_blobs:
|
||||
list_blobs(args)
|
||||
else:
|
||||
gen_blobs(args)
|
||||
gen_blobs(args)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_add_rmsnorm2d_rdquant_fwd")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}")
|
||||
message(DEBUG "adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL add_rmsnorm2d_rdquant_fwd.cpp)
|
||||
target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
@@ -67,13 +67,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using TypeConfig = AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>;
|
||||
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = float;
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = float;
|
||||
using UnquantYDataType = ck_tile::null_type;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host({m, n}, {stride, 1});
|
||||
@@ -184,6 +185,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// Rmsnorm2d
|
||||
{
|
||||
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
|
||||
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n});
|
||||
|
||||
// CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for
|
||||
// simplicity
|
||||
@@ -191,8 +193,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
|
||||
InvRmsDataType,
|
||||
UnquantYDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
|
||||
}
|
||||
|
||||
// yscale
|
||||
|
||||
@@ -80,22 +80,23 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
|
||||
using InputDataType = ck_tile::remove_cvref_t<InputDataType_>;
|
||||
using QuantizedDataType = ck_tile::remove_cvref_t<QuantizedDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr auto WarpSize = ck_tile::get_warp_size();
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (WarpSize / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / WarpSize);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -103,13 +104,13 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
|
||||
return ThreadPerBlock_N_ / WarpSize;
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -186,7 +186,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// Rmsnorm2d
|
||||
{
|
||||
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
|
||||
|
||||
ck_tile::HostTensor<ck_tile::null_type> unquant_y_host_ref({m, n});
|
||||
// CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for
|
||||
// simplicity
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
@@ -194,7 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
|
||||
}
|
||||
|
||||
// yscale
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
function (add_smoothquant_example TARGET_NAME MAIN_SRC)
|
||||
message("adding ${TARGET_NAME}")
|
||||
message(DEBUG "adding ${TARGET_NAME}")
|
||||
# not using add_example_executable() to add target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
|
||||
|
||||
@@ -49,22 +49,22 @@ struct smoothquant_traits_
|
||||
{
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -72,13 +72,13 @@ struct smoothquant_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -14,14 +14,24 @@ This will result in an executable `build/bin/tile_example_moe_sorting`
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
-pr_i index data type. (currently only fp32 supported now) (default:int32)
|
||||
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
|
||||
-t number of input tokens (default:32)
|
||||
-e number of experts (default:8)
|
||||
-k topk (default:2)
|
||||
-st_i row stride of input, -1 means same as experts (default:-1)
|
||||
-seed seed to be used, -1 means random every time (default:-1)
|
||||
-kname when set to 1 it will print kernel name (default:0)
|
||||
-v turn CPU validation on (1) or off (0). (default:1)
|
||||
-pr_i index data type. Only int32 is currently supported. (default:int32)
|
||||
-pr_w output weight data type. Only fp32 is currently supported. (default:fp32)
|
||||
-t number of input tokens. (default:128)
|
||||
If "local_t" presents, this value indicates global concurrency of all ranks.
|
||||
-local_t Number of local input tokens for curent rank. (default:-1)
|
||||
This value must be within range "[0, t)", or "-1"(no such feature)
|
||||
This feature is to simulate EP case where where each rank has different tokens.
|
||||
Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.
|
||||
-e number of num_experts (default:8)
|
||||
-k topk (default:4)
|
||||
-unit unit_size (default:32)
|
||||
-moe_buf_size moe_buf_size (default:0)
|
||||
-local_eid a list of experts enabled as local expert. e.g. "0,1,4,5" (default:-1)
|
||||
please make sure eid is in ascending order!
|
||||
-seed seed to be used. When set to -1, a random seed will be generated each time invoking this example (default:-1)
|
||||
-kname prints the kernel name when set to 1 (default:0)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
|
||||
```
|
||||
|
||||
@@ -18,10 +18,20 @@
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("pr_i", "int32", "index data type. (currently only int32 supported now)")
|
||||
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
|
||||
.insert("t", "128", "number of input tokens")
|
||||
arg_parser.insert("v", "1", "turn CPU validation on (1) or off (0).")
|
||||
.insert("pr_i", "int32", "index data type. Only int32 is currently supported.")
|
||||
.insert("pr_w", "fp32", "output weight data type. Only fp32 is currently supported.")
|
||||
.insert("t",
|
||||
"128",
|
||||
"number of input tokens.\n"
|
||||
"If \"local_t\" presents, this value indicates global concurrency of all ranks.")
|
||||
.insert(
|
||||
"local_t",
|
||||
"-1",
|
||||
"Number of local input tokens for curent rank.\n"
|
||||
"This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
|
||||
"This feature is to simulate EP case where where each rank has different tokens.\n"
|
||||
"Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
|
||||
.insert("e", "8", "number of num_experts")
|
||||
.insert("k", "4", "topk")
|
||||
.insert("unit", "32", "unit_size")
|
||||
@@ -30,8 +40,11 @@ auto create_args(int argc, char* argv[])
|
||||
"-1",
|
||||
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
|
||||
"please make sure eid is in ascending order!")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "when set to 1 it will print kernel name")
|
||||
.insert("seed",
|
||||
"-1",
|
||||
"seed to be used. When set to -1, a random seed will be generated each time "
|
||||
"invoking this example")
|
||||
.insert("kname", "0", "prints the kernel name when set to 1")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
@@ -70,6 +83,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
int tokens = args.get_int("t");
|
||||
int local_tokens = args.get_int("local_t");
|
||||
int num_experts = args.get_int("e");
|
||||
int topk = args.get_int("k");
|
||||
int seed = args.get_int("seed");
|
||||
@@ -95,6 +109,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
return false;
|
||||
}
|
||||
|
||||
// if local_tokens == tokens, not local_token, but better avoid this since no meaning for such
|
||||
// case
|
||||
bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
|
||||
|
||||
if(local_tokens > tokens)
|
||||
{
|
||||
printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool local_expert_masking = args.get_str("local_eid") != "-1";
|
||||
auto local_expert_masking_host = [&]() {
|
||||
if(local_expert_masking)
|
||||
@@ -143,6 +167,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
ck_tile::DeviceMem local_expert_masking_dev(
|
||||
local_expert_masking_host.get_element_space_size_in_bytes());
|
||||
|
||||
// used for simulating dynamic_tokens for EP case
|
||||
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
|
||||
if(is_local_token)
|
||||
{
|
||||
local_tokens_dev.ToDevice(&local_tokens);
|
||||
}
|
||||
|
||||
topk_ids_dev.ToDevice(topk_ids_host.data());
|
||||
weights_dev.ToDevice(weights_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
@@ -164,6 +195,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(),
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
@@ -236,13 +268,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
}
|
||||
#endif
|
||||
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, mp:%d, ",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
workspace_size != 0 ? 1 : 0);
|
||||
printf("[%s|%s]tokens:%d", index_prec.c_str(), weight_prec.c_str(), tokens);
|
||||
if(is_local_token)
|
||||
{
|
||||
printf("(%d)", local_tokens);
|
||||
}
|
||||
printf(", num_experts:%d, topk:%d, mp:%d, ", num_experts, topk, workspace_size != 0 ? 1 : 0);
|
||||
|
||||
if(local_expert_masking)
|
||||
{
|
||||
@@ -285,6 +316,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
ref_total_tokens_post_pad,
|
||||
num_experts,
|
||||
unit_size,
|
||||
is_local_token ? local_tokens
|
||||
: tokens,
|
||||
local_expert_masking);
|
||||
printf("total_tokens_post_pad:%d(%d), ",
|
||||
ref_total_tokens_post_pad,
|
||||
@@ -334,16 +367,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
|
||||
bool r = true;
|
||||
if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0)
|
||||
try
|
||||
{
|
||||
r &= test_moe_sorting<float, ck_tile::index_t>(args);
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
|
||||
bool r = true;
|
||||
if(weight_prec == "fp32" && index_prec == "int32")
|
||||
{
|
||||
r &= test_moe_sorting<float, ck_tile::index_t>(args);
|
||||
}
|
||||
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
|
||||
@@ -33,15 +33,18 @@
|
||||
|
||||
#else
|
||||
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
|
||||
#define MOE_SORTING_DISPATCH_( \
|
||||
sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking>; \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -51,32 +54,43 @@
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \
|
||||
row_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
|
||||
#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
|
||||
@@ -171,6 +185,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
auto row_ = sub_token_ / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(row_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
@@ -179,15 +194,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -195,15 +212,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -211,15 +230,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -227,15 +248,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -244,15 +267,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
}()
|
||||
#endif
|
||||
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -261,28 +286,53 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \
|
||||
return ave_time; \
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
}
|
||||
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
|
||||
@@ -31,4 +31,14 @@ $EXE -t=8192 -e=32 -k=5 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=32 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=256 -k=5 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=256 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=12 -local_t=3 -e=256 -k=5 -local_eid=9,10,199,145
|
||||
$EXE -t=67 -local_t=9 -e=555 -k=5 -local_eid=19,23,24,25,26,99
|
||||
$EXE -t=99 -local_t=93 -e=121 -moe_buf_size=10244
|
||||
$EXE -t=536 -local_t=345 -e=802 -k=99
|
||||
$EXE -t=331 -local_t=39 -e=83 -k=33
|
||||
$EXE -t=765 -local_t=654 -e=783 -k=8
|
||||
$EXE -t=23 -local_t=9 -e=1 -k=1
|
||||
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
|
||||
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
|
||||
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -moe_buf_size=133940
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
|
||||
message("adding ${TARGET_NAME}")
|
||||
message(DEBUG "adding ${TARGET_NAME}")
|
||||
# not using add_example_executable() to add target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
|
||||
|
||||
@@ -38,22 +38,22 @@ struct moe_smoothquant_traits_
|
||||
using InputType = ck_tile::remove_cvref_t<InputType_>;
|
||||
using OutputType = ck_tile::remove_cvref_t<OutputType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -61,13 +61,13 @@ struct moe_smoothquant_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding ${TILE_EXAPMLE_FUSED_MOE}")
|
||||
message(DEBUG "adding ${TILE_EXAPMLE_FUSED_MOE}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp)
|
||||
target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
@@ -16,6 +16,7 @@ struct fused_moe_args
|
||||
const void* d_scale_ptr; // [e, 1, k], down scale
|
||||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
||||
const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
|
||||
const void* local_tokens; // [1] if not nullptr, tokens read from here
|
||||
void* o_ptr; // [m, k], output token (no need to do zeroing)
|
||||
void* ws_ptr; // size is moe_sorting_get_workspace_size()
|
||||
// if return zero, then could be nullptr
|
||||
|
||||
@@ -28,6 +28,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
|
||||
@@ -33,15 +33,18 @@
|
||||
|
||||
#else
|
||||
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
|
||||
#define MOE_SORTING_DISPATCH_( \
|
||||
sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking>; \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -51,32 +54,43 @@
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \
|
||||
row_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
|
||||
#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
|
||||
@@ -175,6 +189,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
auto row_ = sub_token_ / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(row_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
@@ -183,15 +198,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -199,15 +216,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -215,15 +234,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -231,15 +252,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -248,15 +271,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
}()
|
||||
#endif
|
||||
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -265,30 +290,55 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \
|
||||
return ave_time; \
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
}
|
||||
|
||||
float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
fused_moesorting_args a,
|
||||
ck_tile::stream_config s)
|
||||
{
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
@@ -360,3 +410,8 @@ float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
}
|
||||
|
||||
@@ -87,7 +87,18 @@ void topid_unique_gen(
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("t", "128", "num input tokens")
|
||||
arg_parser
|
||||
.insert("t",
|
||||
"128",
|
||||
"number of input tokens.\n"
|
||||
"If \"local_t\" presents, this value indicates global concurrency of all ranks.")
|
||||
.insert(
|
||||
"local_t",
|
||||
"-1",
|
||||
"Number of local input tokens for curent rank.\n"
|
||||
"This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
|
||||
"This feature is to simulate EP case where where each rank has different tokens.\n"
|
||||
"Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
|
||||
.insert("e", "32", "num of experts")
|
||||
.insert("k", "5", "topk")
|
||||
.insert("h", "8192", "hidden_size of this model")
|
||||
@@ -131,6 +142,7 @@ template <typename I, typename W, typename O, typename ST, typename SW, typename
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t tokens = arg_parser.get_int("t");
|
||||
ck_tile::index_t local_tokens = arg_parser.get_int("local_t");
|
||||
ck_tile::index_t experts = arg_parser.get_int("e");
|
||||
ck_tile::index_t topk = arg_parser.get_int("k");
|
||||
ck_tile::index_t hidden_size = arg_parser.get_int("h");
|
||||
@@ -169,6 +181,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// w1 (Down, N size)
|
||||
ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
|
||||
|
||||
bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
|
||||
|
||||
if(local_tokens > tokens)
|
||||
{
|
||||
printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto prec_str = [&]() {
|
||||
auto base_str = prec_i;
|
||||
if(prec_i != prec_w)
|
||||
@@ -198,11 +218,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return std::string(", st:") + std::to_string(stride);
|
||||
}();
|
||||
|
||||
std::cout << "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens;
|
||||
|
||||
if(is_local_token)
|
||||
{
|
||||
std::cout << "(" << local_tokens << ")";
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
|
||||
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
|
||||
<< ", act:"
|
||||
<< ", e:" << experts << ", k:" << topk << stride_str << ", hidden:" << hidden_size
|
||||
<< ", interm:" << intermediate_size << ", tp:" << tp << ", act:"
|
||||
<< activation
|
||||
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
|
||||
<< (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
|
||||
@@ -377,6 +403,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
|
||||
if(is_local_token)
|
||||
{
|
||||
local_tokens_dev.ToDevice(&local_tokens);
|
||||
}
|
||||
|
||||
fused_moe_traits traits{prec_i,
|
||||
prec_w,
|
||||
@@ -400,6 +431,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
|
||||
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
@@ -463,6 +495,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m,
|
||||
is_local_token ? local_tokens : tokens,
|
||||
local_expert_masking);
|
||||
if(activation == 0)
|
||||
{
|
||||
@@ -495,6 +528,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m,
|
||||
is_local_token ? local_tokens : tokens,
|
||||
local_expert_masking);
|
||||
|
||||
// done, preparing GPU buffer
|
||||
@@ -506,6 +540,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
ck_tile::DeviceMem sy_buf(sy_host);
|
||||
ck_tile::DeviceMem o_buf(o_host);
|
||||
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
|
||||
if(is_local_token)
|
||||
{
|
||||
local_tokens_dev.ToDevice(&local_tokens);
|
||||
}
|
||||
|
||||
// manually clear output buffer for atomic
|
||||
o_buf.SetZero();
|
||||
@@ -542,7 +581,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
hidden_size,
|
||||
intermediate_size / tp,
|
||||
tokens,
|
||||
is_local_token ? local_tokens : tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride};
|
||||
|
||||
@@ -15,7 +15,16 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "batched_gemm.hpp"
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
@@ -123,12 +132,16 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -139,6 +152,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -183,141 +197,22 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
}
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got "
|
||||
<< tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Tail pipeline One to Seven
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Two)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Four)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Five)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Six)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Seven)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
}
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
std::ostringstream err;
|
||||
err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but "
|
||||
"got "
|
||||
<< tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_batched_gemm_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
|
||||
@@ -23,7 +23,16 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
@@ -44,20 +53,29 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::BatchedGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
args.stride_E = stride_C;
|
||||
args.batch_stride_A = batch_stride_A;
|
||||
args.batch_stride_B = batch_stride_B;
|
||||
args.batch_stride_C = batch_stride_C;
|
||||
args.batch_stride_E = batch_stride_C;
|
||||
args.batch_count = batch_count;
|
||||
|
||||
float ave_time = batched_gemm<ALayout, BLayout, CLayout>(
|
||||
float ave_time = batched_gemm<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Batched Gemm"};
|
||||
@@ -169,22 +187,30 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_batched_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
invoke_batched_gemm<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
|
||||
|
||||
add_executable(tile_example_grouped_gemm_tileloop EXCLUDE_FROM_ALL grouped_gemm_tileloop.cpp)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Grouped CShuffle GEMM
|
||||
|
||||
This folder contains example for Grouped GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile.
|
||||
This folder contains example for Grouped GEMM using ck_tile tile-programming implementation.
|
||||
|
||||
## build
|
||||
```
|
||||
|
||||
@@ -16,15 +16,19 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_gemm.hpp"
|
||||
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_)
|
||||
void* kargs_ptr)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
@@ -114,70 +118,79 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
|
||||
gemm_descs.size()));
|
||||
return ave_time;
|
||||
};
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
@@ -196,125 +209,23 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
}
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got "
|
||||
<< tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Tail pipeline One to Seven
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Two)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Four)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Five)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Six)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Seven)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
}
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but "
|
||||
<< "got " << tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
constexpr bool Persistent = false;
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_grouped_gemm_example<Persistent>(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
@@ -53,7 +54,7 @@ using BDataType = Types::BDataType;
|
||||
using AccDataType = Types::AccDataType;
|
||||
using CDataType = Types::CDataType;
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs</*NumDTensor = 0*/>;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -70,14 +71,35 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.");
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs);
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_);
|
||||
void* kargs_ptr);
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk = false);
|
||||
|
||||
177
example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp
Normal file
177
example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp
Normal file
@@ -0,0 +1,177 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_gemm.hpp"
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 4;
|
||||
constexpr ck_tile::index_t N_Warp = 1;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
#endif
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = true;
|
||||
#endif
|
||||
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(!splitk)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
constexpr bool Persistent = true;
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example<Persistent>(argc, argv); }
|
||||
@@ -30,20 +30,81 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int group_count,
|
||||
const std::vector<grouped_gemm_kargs>& args)
|
||||
{
|
||||
|
||||
// Workspace memory allocated to hold the gemm descriptions.
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(args));
|
||||
|
||||
float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
float ave_time = 0;
|
||||
if constexpr(!Persistent)
|
||||
{
|
||||
// Regular version of grouped gemm
|
||||
ave_time = grouped_gemm<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host. Instead, we can just pass the pointer
|
||||
// to the kernel and let the workgroups figure out which tiles to work on.
|
||||
// This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
std::vector<ck_tile::GemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = args[0].k_batch > 1;
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
{},
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
{},
|
||||
arg.stride_E,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<ALayout, BLayout, CLayout>(
|
||||
stream, group_count, kargs_ptr, splitk);
|
||||
}
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
|
||||
@@ -66,7 +127,7 @@ float invoke_gemm(int n_warmup,
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <bool Persistent, typename ALayout, typename BLayout, typename CLayout>
|
||||
int run_grouped_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
@@ -87,6 +148,15 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
|
||||
if(kbatch > 1 && validate && warmup + repeat > 1)
|
||||
{
|
||||
std::cout << "WARNING: Data validation enabled with SplitK and more than"
|
||||
<< "1 warmup/repeat. Disabling validation." << std::endl;
|
||||
validate = false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
@@ -102,7 +172,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(256 + 64 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
@@ -150,8 +220,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
<< " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
@@ -169,13 +239,20 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
// TODO Add support for kbatch > 1 in grouped gemm
|
||||
static constexpr ck_tile::index_t k_batch = 1;
|
||||
gemm_descs.push_back(
|
||||
{p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
{p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]});
|
||||
}
|
||||
|
||||
invoke_gemm<ALayout, BLayout, CLayout>(warmup, repeat, group_count, gemm_descs);
|
||||
invoke_gemm<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
Persistent>(warmup, repeat, group_count, gemm_descs);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
@@ -183,7 +260,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
}
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("validate"))
|
||||
if(validate)
|
||||
{
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
@@ -194,7 +271,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value);
|
||||
const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
@@ -211,6 +288,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <bool Persistent>
|
||||
int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -227,12 +305,20 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
|
||||
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
// else if(a_layout == "R" && b_layout == "R")
|
||||
// {
|
||||
// return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
|
||||
// }
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
|
||||
@@ -3,6 +3,4 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
|
||||
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
|
||||
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1 -Wno-unused-local-typedef)
|
||||
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -DENABLE_FP8=1 -Wno-unused-local-typedef)
|
||||
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -11,81 +11,60 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "flatmm_basic.hpp"
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename FlatmmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr int kBlockPerCu = 2;
|
||||
|
||||
// This part comes from the Codegen
|
||||
#if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16)
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 128;
|
||||
constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 1;
|
||||
constexpr ck_tile::index_t N_Warp = 4;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 64 : 16;
|
||||
|
||||
#elif defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8)
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 1;
|
||||
constexpr ck_tile::index_t N_Warp = 8;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 16;
|
||||
#endif
|
||||
using CodegenFlatmmShape =
|
||||
ck_tile::TileFlatmmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using CodegenFlatmmShape = ck_tile::TileFlatmmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenFlatmmShape>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using CodegenGemmTraits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
@@ -109,15 +88,57 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName()
|
||||
<< CodegenPipelineProblem::GetName() << " grid: {" << grids.x << ", "
|
||||
<< grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
float ave_time{0};
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
if(args.k_batch == 1)
|
||||
@@ -132,8 +153,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
template <template <typename PreType> typename FlatmmConfig>
|
||||
int run_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -146,24 +166,27 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -177,4 +200,35 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig32>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 2)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
|
||||
}
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,63 @@
|
||||
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
|
||||
#endif
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
// GEMM config with 32x132 warp tile
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig32
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 32;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig32_950 : public FlatmmConfig32<DataType>
|
||||
{
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 64;
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 64;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
|
||||
{
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
|
||||
};
|
||||
|
||||
template <typename ADataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
|
||||
template <>
|
||||
@@ -103,10 +159,10 @@ struct DataTypeTraits<ck_tile::half_t>
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_8bit_type
|
||||
: std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
@@ -126,11 +182,22 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value");
|
||||
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename FlatmmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -32,38 +32,20 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
}
|
||||
|
||||
// mfma_type, 0:32x32, 1:16x16
|
||||
template <typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type)
|
||||
template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({n_ / 32, 32, k_ / 16, 2, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({n_ / 16, 16, k_ / 32, 4, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({n_ / 32, 32, k_ / 32, 2, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({n_ / 16, 16, k_ / 64, 4, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
return t;
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
FlatmmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
@@ -91,6 +73,7 @@ template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename FlatmmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
@@ -120,9 +103,15 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
float ave_time =
|
||||
flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
float ave_time = flatmm_calc<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
FlatmmConfig,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
@@ -138,7 +127,11 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename PrecType,
|
||||
typename FlatmmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_flatmm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
@@ -162,9 +155,10 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
@@ -178,8 +172,26 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
// TODO: add different init types
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_host.SetZero();
|
||||
b_origin_host.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
|
||||
@@ -188,29 +200,29 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
c_rslt_host.SetZero();
|
||||
|
||||
// do pre-shuffle
|
||||
std::string mfma = arg_parser.get_str("prec");
|
||||
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
|
||||
ck_tile::index_t mfma_type = 1;
|
||||
#else
|
||||
ck_tile::index_t mfma_type = 0;
|
||||
#endif
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b(b_origin_host, mfma, mfma_type);
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig>(b_origin_host);
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
|
||||
invoke_flatmm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
invoke_flatmm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
FlatmmConfig,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
bool pass = true;
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE="$(find . -name tile_example_flatmm_basic -type f | head -n 1)"
|
||||
KNAME=1
|
||||
|
||||
|
||||
6
example/ck_tile/19_gemm_multi_d/CMakeLists.txt
Normal file
6
example/ck_tile/19_gemm_multi_d/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
add_executable(tile_example_gemm_multi_d_fp16 EXCLUDE_FROM_ALL gemm_multi_d_fp16.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
target_compile_options(tile_example_gemm_multi_d_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
35
example/ck_tile/19_gemm_multi_d/README.md
Normal file
35
example/ck_tile/19_gemm_multi_d/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
#Multiple D GEMM
|
||||
|
||||
This folder contains example for Multiple D GEMM using ck_tile tile-programming implementation.
|
||||
|
||||
## build
|
||||
```
|
||||
#in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \
|
||||
leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
#The basic pipeline method on the gemm calculation
|
||||
make tile_example_gemm_multi_d_fp16 -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_multi_d_fp16`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m M dimensions - (Default: 3840)
|
||||
-n N dimensions - (Default: 4096)
|
||||
-k K dimensions - (Default: 4096)
|
||||
-a_layout Tensor A layout (default:R)
|
||||
-b_layout Tensor B layout (default:C)
|
||||
-ds_layout Tensor D layout (default:R)
|
||||
-e_layout Tensor E layout (default:R)
|
||||
-stride_a Tensor A strides - (Default: 0)
|
||||
-stride_b Tensor B strides - (Default: 0)
|
||||
-stride_e Tensor C strides - (Default: 0)
|
||||
-stride_ds Tensor D strides - (Default: 0)
|
||||
-validate 0. No validation, 1. Validation on GPU. (Default: 1)
|
||||
-warmup Number of iterations before benchmark the kernel. (Default: 10)
|
||||
-repeat Number of iterations to benchmark the kernel. (Default: 100)
|
||||
-kbatch kbatch for SplitK. (Default 1)
|
||||
```
|
||||
296
example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp
Normal file
296
example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp
Normal file
@@ -0,0 +1,296 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_multi_d_fp16.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& s) -> float
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 4;
|
||||
constexpr ck_tile::index_t N_Warp = 1;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
#endif
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = true;
|
||||
#endif
|
||||
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "For compute pipeline tail number should always be Full, but have \"" << tail_num
|
||||
<< "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
auto check_tail = [&](auto... TNs) {
|
||||
(try_run<BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
|
||||
};
|
||||
|
||||
check_tail(ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_gemm_multi_d_fp16_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_multiple_d_gemm_example(argc, argv); }
|
||||
79
example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp
Normal file
79
example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp
Normal file
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
|
||||
#endif
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
#else
|
||||
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
|
||||
#endif
|
||||
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using D0DataType = ck_tile::half_t;
|
||||
using D1DataType = ck_tile::half_t;
|
||||
using EDataType = ck_tile::half_t;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
using AccDataType = float;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "4096", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Col by default")
|
||||
.insert("ds_layout", "R", "Ds tensor data layout - Row by default")
|
||||
.insert("e_layout", "R", "E tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_ds", "0", "Tensor Ds stride")
|
||||
.insert("stride_e", "0", "Tensor E stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on GPU")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("kbatch", "1", "kbatch for SplitK");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
using gemm_multi_d_kargs = ck_tile::GemmHostArgs<DsDataType::size()>;
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise>
|
||||
float gemm_multi_d(const gemm_multi_d_kargs& kargs, const ck_tile::stream_config& s);
|
||||
@@ -0,0 +1,247 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm_multi_d(const void* a_m_k_dev_buf,
|
||||
const void* b_k_n_dev_buf,
|
||||
const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
|
||||
void* e_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t StrideA,
|
||||
ck_tile::index_t StrideB,
|
||||
const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
|
||||
ck_tile::index_t StrideE,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
int k_batch)
|
||||
{
|
||||
gemm_multi_d_kargs gemm_descs({a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
ds_m_n_dev_buf,
|
||||
e_m_n_dev_buf,
|
||||
k_batch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE});
|
||||
|
||||
float ave_time = gemm_multi_d<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise>(
|
||||
gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Gemm Multiple-D"};
|
||||
static constexpr ck_tile::index_t NumDTensor = DsDataType::size();
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
flop += std::size_t(2) * M * N * K;
|
||||
|
||||
ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
num_btype += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) * M * N;
|
||||
flop += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) * M * N;
|
||||
});
|
||||
|
||||
num_btype += sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm Multiple-D kernel with:\n";
|
||||
std::cout << "M =" << M << " N =" << N << " K =" << K << "\n";
|
||||
std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideE = " << StrideE
|
||||
<< "\n";
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< "\n";
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
typename D1Layout,
|
||||
typename ELayout>
|
||||
int run_multiple_d_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const D0Layout d0_layout = D0Layout{},
|
||||
const D1Layout d1_layout = D1Layout{},
|
||||
const ELayout e_layout = ELayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
using CDElementWiseFn = MultiplyMultiply;
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t StrideA = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t StrideB = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t StrideD = arg_parser.get_int("stride_ds");
|
||||
ck_tile::index_t StrideE = arg_parser.get_int("stride_e");
|
||||
|
||||
ck_tile::index_t StrideD0 = StrideD;
|
||||
ck_tile::index_t StrideD1 = StrideD;
|
||||
|
||||
const int n_warmup = arg_parser.get_int("warmup");
|
||||
const int n_repeat = arg_parser.get_int("repeat");
|
||||
const int k_batch = arg_parser.get_int("kbatch");
|
||||
|
||||
StrideA = get_default_stride(M, K, StrideA, is_row_major(a_layout));
|
||||
StrideB = get_default_stride(K, N, StrideB, is_row_major(b_layout));
|
||||
StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout));
|
||||
StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout));
|
||||
StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k_tesnor(
|
||||
host_tensor_descriptor(M, K, StrideA, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n_tensors(
|
||||
host_tensor_descriptor(K, N, StrideB, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<D0DataType> d0_m_n_tensors(
|
||||
host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout)));
|
||||
ck_tile::HostTensor<D1DataType> d1_m_n_tensors(
|
||||
host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout)));
|
||||
ck_tile::HostTensor<EDataType> e_m_n_device_result(
|
||||
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tesnor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors);
|
||||
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n_tensors);
|
||||
ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n_tensors);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k_tesnor.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_tesnor.mData.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_tensors.mData.data());
|
||||
d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data());
|
||||
d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data());
|
||||
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_device_result.SetZero();
|
||||
|
||||
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
|
||||
d1_m_n_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {StrideD0, StrideD1};
|
||||
|
||||
invoke_gemm_multi_d<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDElementWiseFn>(a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
ds_ptr_buf,
|
||||
e_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
stridesDs,
|
||||
StrideE,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
k_batch);
|
||||
|
||||
e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
|
||||
|
||||
ck_tile::HostTensor<EDataType> e_m_n_host_ref(
|
||||
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
|
||||
e_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm_multiple_d<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
CDElementWiseFn>(
|
||||
a_m_k_tesnor, b_k_n_tensors, {d0_m_n_tensors, d1_m_n_tensors}, e_m_n_host_ref);
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("v"))
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end());
|
||||
|
||||
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
|
||||
|
||||
pass &= ck_tile::check_err(e_m_n_device_result,
|
||||
e_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< std::endl;
|
||||
std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
int run_multiple_d_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string ds_layout = arg_parser.get_str("ds_layout");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C" && ds_layout == "R")
|
||||
{
|
||||
return run_multiple_d_gemm_example_with_layouts(
|
||||
argc, argv, Row{}, Col{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for provided tensors!");
|
||||
}
|
||||
}
|
||||
50
example/ck_tile/19_gemm_multi_d/utils.hpp
Normal file
50
example/ck_tile/19_gemm_multi_d/utils.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
struct MultiplyMultiply
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
|
||||
{
|
||||
const float x0_f = ck_tile::type_convert<float>(c) * ck_tile::type_convert<float>(d0) *
|
||||
ck_tile::type_convert<float>(d1);
|
||||
|
||||
e = ck_tile::type_convert<E>(x0_f);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeTypeAB =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
|
||||
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
4
example/ck_tile/20_grouped_convolution/CMakeLists.txt
Normal file
4
example/ck_tile/20_grouped_convolution/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
|
||||
set(EXAMPLE_CONV_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
@@ -0,0 +1,207 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using CodegenShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType =
|
||||
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
CodegenPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
|
||||
#include "run_grouped_convolution_example.inc"
|
||||
|
||||
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_fwd_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
|
||||
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<1>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
}
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<2>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
|
||||
}
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "GKZYXC")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<3>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout!");
|
||||
}
|
||||
}
|
||||
|
||||
int run_grouped_conv_fwd_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string in_layout = arg_parser.get_str("in_layout");
|
||||
std::string wei_layout = arg_parser.get_str("weight_layout");
|
||||
std::string out_layout = arg_parser.get_str("out_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_prec_type<ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_prec_type<ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_example(argc, argv); }
|
||||
@@ -0,0 +1,108 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
|
||||
ck_tile::index_t fill_spatial_dimensions(std::vector<ck_tile::index_t>& filter_spatial_lengths,
|
||||
std::vector<ck_tile::index_t>& image_spatial_lengths,
|
||||
std::vector<ck_tile::index_t>& strides,
|
||||
std::vector<ck_tile::index_t>& dilations,
|
||||
std::vector<ck_tile::index_t>& lpads,
|
||||
std::vector<ck_tile::index_t>& rpads,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
|
||||
constexpr ck_tile::index_t non_sp_dims = 3;
|
||||
const ck_tile::index_t n_dim_sp = arg_parser.get_str("in_layout").size() - non_sp_dims;
|
||||
|
||||
if(!(n_dim_sp >= 1 && n_dim_sp <= 3))
|
||||
{
|
||||
throw std::runtime_error("Wrong layout!\n");
|
||||
}
|
||||
|
||||
if(n_dim_sp == 3)
|
||||
{
|
||||
filter_spatial_lengths.push_back(arg_parser.get_int("z"));
|
||||
image_spatial_lengths.push_back(arg_parser.get_int("d"));
|
||||
strides.push_back(arg_parser.get_int("stride_d"));
|
||||
dilations.push_back(arg_parser.get_int("dilation_d"));
|
||||
lpads.push_back(arg_parser.get_int("lpad_d"));
|
||||
rpads.push_back(arg_parser.get_int("rpad_d"));
|
||||
}
|
||||
if(n_dim_sp >= 2)
|
||||
{
|
||||
filter_spatial_lengths.push_back(arg_parser.get_int("y"));
|
||||
image_spatial_lengths.push_back(arg_parser.get_int("h"));
|
||||
strides.push_back(arg_parser.get_int("stride_h"));
|
||||
dilations.push_back(arg_parser.get_int("dilation_h"));
|
||||
lpads.push_back(arg_parser.get_int("lpad_h"));
|
||||
rpads.push_back(arg_parser.get_int("rpad_h"));
|
||||
}
|
||||
filter_spatial_lengths.push_back(arg_parser.get_int("x"));
|
||||
image_spatial_lengths.push_back(arg_parser.get_int("w"));
|
||||
strides.push_back(arg_parser.get_int("stride_w"));
|
||||
dilations.push_back(arg_parser.get_int("dilation_w"));
|
||||
lpads.push_back(arg_parser.get_int("lpad_w"));
|
||||
rpads.push_back(arg_parser.get_int("rpad_w"));
|
||||
|
||||
return n_dim_sp;
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("g", "2", "group dimension")
|
||||
.insert("n", "32", "n dimension")
|
||||
.insert("k", "32", "k dimension")
|
||||
.insert("c", "32", "c dimension")
|
||||
|
||||
.insert("d", "64", "d dimension")
|
||||
.insert("h", "64", "h dimension")
|
||||
.insert("w", "64", "w dimension")
|
||||
|
||||
.insert("z", "4", "z dimension")
|
||||
.insert("y", "4", "y dimension")
|
||||
.insert("x", "4", "x dimension")
|
||||
|
||||
.insert("stride_d", "1", "d stride")
|
||||
.insert("stride_h", "1", "h stride")
|
||||
.insert("stride_w", "1", "w stride")
|
||||
|
||||
.insert("dilation_d", "1", "d dilation")
|
||||
.insert("dilation_h", "1", "h dilation")
|
||||
.insert("dilation_w", "1", "w dilation")
|
||||
|
||||
.insert("lpad_d", "0", "left pad for d dimension")
|
||||
.insert("lpad_h", "0", "left pad for h dimension")
|
||||
.insert("lpad_w", "0", "left pad for w dimension")
|
||||
|
||||
.insert("rpad_d", "0", "right pad for d dimension")
|
||||
.insert("rpad_h", "0", "right pad for h dimension")
|
||||
.insert("rpad_w", "0", "right pad for w dimension")
|
||||
|
||||
.insert("in_layout", "NHWGC", "Input image layout - NHWGC by default")
|
||||
.insert("weight_layout", "GKYXC", "Weight layout - GKYXC by default")
|
||||
.insert("out_layout", "NHWGK", "Output image layout - NHWGK by default")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s);
|
||||
@@ -0,0 +1,206 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(InDataType) < sizeof(WeiDataType), InDataType, WeiDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, OutDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(GemmK, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, OutDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(GemmK, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<OutDataType, OutDataType, OutDataType>(kbatch);
|
||||
const auto atol_split_k =
|
||||
ck_tile::get_absolute_threshold<OutDataType, OutDataType, OutDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_fwd(ck_tile::GroupedConvHostArgs& args, int n_warmup, int n_repeat)
|
||||
{
|
||||
float ave_time = grouped_conv_fwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = args.GetFlops();
|
||||
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
int run_grouped_conv_fwd_example_with_layouts(
|
||||
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using AccDataType = float;
|
||||
|
||||
std::vector<ck_tile::index_t> filter_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> image_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> strides;
|
||||
std::vector<ck_tile::index_t> dilations;
|
||||
std::vector<ck_tile::index_t> lpads;
|
||||
std::vector<ck_tile::index_t> rpads;
|
||||
|
||||
const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads,
|
||||
arg_parser);
|
||||
|
||||
ck_tile::conv::ConvParam conv_param{num_dim_sp,
|
||||
arg_parser.get_int("g"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("c"),
|
||||
filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads};
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
|
||||
|
||||
ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
|
||||
ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(input);
|
||||
ck_tile::FillUniformDistribution<WeiDataType>{-5.f, 5.f}(weight);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<InDataType>{}(input);
|
||||
ck_tile::FillMonotonicSeq<WeiDataType>{}(weight);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InDataType>{1.f, 1.f}(input);
|
||||
ck_tile::FillUniformDistribution<WeiDataType>{1.f, 1.f}(weight);
|
||||
}
|
||||
else
|
||||
{
|
||||
input.SetZero();
|
||||
weight.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
|
||||
|
||||
input_dev_buf.ToDevice(input.data());
|
||||
weight_dev_buf.ToDevice(weight.data());
|
||||
output_dev_buf.SetZero();
|
||||
|
||||
ck_tile::GroupedConvHostArgs args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
|
||||
std::cout << "Run Grouped Conv Fwd kernel" << std::endl;
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_fwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
|
||||
output_dev_buf.FromDevice(output.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<OutDataType> output_host_ref(out_g_n_k_wos_desc);
|
||||
output_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_grouped_conv_fwd<NDimSpatial, InDataType, WeiDataType, OutDataType>(
|
||||
input,
|
||||
weight,
|
||||
output_host_ref,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_);
|
||||
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(output_host_ref.mData.begin(), output_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
|
||||
GemmK, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(output,
|
||||
output_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
throw std::runtime_error("Unsupported gpu verification !!!");
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -1,5 +1,8 @@
|
||||
#!/bin/sh
|
||||
|
||||
# Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE=./build/bin/tile_example_batched_transpose
|
||||
|
||||
for pr in "fp8" "fp16" "bf16"; do
|
||||
@@ -8,4 +11,4 @@ $EXE -pr=$pr -N=1 -C=1024 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC'
|
||||
$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC'
|
||||
$EXE -pr=$pr -N=1 -C=4096 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC'
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
#
|
||||
# in order to run this script you'd first need to build the tile_example_batched_transpose executables in ../build/bin/
|
||||
#
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
#!/bin/sh
|
||||
|
||||
# Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE=./build/bin/tile_example_batched_transpose
|
||||
|
||||
for pr in "fp8" "fp16" "bf16"; do
|
||||
@@ -24,4 +27,4 @@ $EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NHWC' -layout_out='NCHW'
|
||||
$EXE -pr=$pr -N=1 -C=64 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC'
|
||||
$EXE -pr=$pr -N=1 -C=64 -H=1024 -W=1 -layout_in='NHWC' -layout_out='NCHW'
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
9
example/ck_tile/37_transpose/CMakeLists.txt
Normal file
9
example/ck_tile/37_transpose/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
set(TARGET_NAME tile_example_transpose)
|
||||
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL transpose_example.cpp transpose_api.cpp)
|
||||
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
target_compile_options(tile_example_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS})
|
||||
|
||||
27
example/ck_tile/37_transpose/README.md
Normal file
27
example/ck_tile/37_transpose/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# Batched Transpose
|
||||
This folder contains example for transpose load for architecture gfx950. This transpose load has some constraints in input tile distribution.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
# Make the transpose executable
|
||||
make tile_example_transpose -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_transpose`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-N input batch size (default:2)
|
||||
-C input channel size. (default:64)
|
||||
-H input height size. (default:1)
|
||||
-W input width size. (default:64)
|
||||
-v whether do CPU validation or not (default: 1)
|
||||
-layout_in input tensor data layout - NCHW by default
|
||||
-layout_out output tensor data layout - NHWC by default
|
||||
-seed seed to be used, -1 means random every time (default:-1)
|
||||
-k_name t to 1 will print kernel name (default:0)
|
||||
```
|
||||
120
example/ck_tile/37_transpose/batched_transpose_kernel.hpp
Normal file
120
example/ck_tile/37_transpose/batched_transpose_kernel.hpp
Normal file
@@ -0,0 +1,120 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BatchedTransposeHostArgs
|
||||
{
|
||||
const void* p_input;
|
||||
void* p_output;
|
||||
index_t batch;
|
||||
index_t height;
|
||||
index_t width;
|
||||
// index_t dim_blocks;
|
||||
index_t dim_stride;
|
||||
index_t dim_block_h;
|
||||
index_t dim_block_w;
|
||||
};
|
||||
|
||||
template <typename Pipeline_>
|
||||
struct BatchedTransposeKernel
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = remove_cvref_t<typename Pipeline::Problem>;
|
||||
|
||||
using Type = typename Problem::DataType;
|
||||
|
||||
struct BatchedTransposeKargs
|
||||
{
|
||||
const void* p_input;
|
||||
void* p_output;
|
||||
index_t batch;
|
||||
index_t height;
|
||||
index_t width;
|
||||
index_t dim_stride;
|
||||
};
|
||||
|
||||
using Kargs = BatchedTransposeKargs;
|
||||
using Hargs = BatchedTransposeHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
size_t grid_size_x = h.dim_block_w;
|
||||
size_t grid_size_y = h.dim_block_h;
|
||||
size_t grid_size_z = h.batch;
|
||||
return dim3(grid_size_x, grid_size_y, grid_size_z);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_input = h.p_input;
|
||||
k.p_output = h.p_output;
|
||||
k.batch = h.batch;
|
||||
k.height = h.height;
|
||||
k.width = h.width;
|
||||
k.dim_stride = h.dim_stride;
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
__shared__ char smem[Pipeline::GetSmemSize()];
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondSizePerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadSizePerBlock;
|
||||
|
||||
const auto iDim = blockIdx.z;
|
||||
const auto x_m_n = [&]() {
|
||||
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
|
||||
make_tuple(kargs.height, kargs.width),
|
||||
make_tuple(kargs.width, 1),
|
||||
number<Pipeline::GetVectorSize()>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(x_dram_naive,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
sequence<false, false>{});
|
||||
}();
|
||||
|
||||
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.y * kMPerBlock);
|
||||
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.x * kNPerBlock);
|
||||
|
||||
const auto y_n_m = [&]() {
|
||||
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
|
||||
make_tuple(kargs.width, kargs.height),
|
||||
make_tuple(kargs.height, 1),
|
||||
number<Pipeline::GetVectorSize()>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(y_dram_naive,
|
||||
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
|
||||
sequence<false, false>{});
|
||||
}();
|
||||
|
||||
auto x_block_window = make_tile_window(
|
||||
x_m_n,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iM), static_cast<ck_tile::index_t>(iN)});
|
||||
|
||||
auto y_block_window = make_tile_window(
|
||||
y_n_m,
|
||||
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iN), static_cast<ck_tile::index_t>(iM)});
|
||||
|
||||
Pipeline{}(x_block_window, y_block_window, smem);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
149
example/ck_tile/37_transpose/block_transpose.hpp
Normal file
149
example/ck_tile/37_transpose/block_transpose.hpp
Normal file
@@ -0,0 +1,149 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "transpose_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Layout_, index_t kRow, index_t kCol>
|
||||
struct TransposeTraits
|
||||
{
|
||||
static constexpr index_t kLeadDim = kCol;
|
||||
static constexpr index_t kSecondDim = kRow;
|
||||
};
|
||||
|
||||
template <index_t kRow, index_t kCol>
|
||||
struct TransposeTraits<tensor_layout::gemm::ColumnMajor, kRow, kCol>
|
||||
{
|
||||
static constexpr index_t kLeadDim = kRow;
|
||||
static constexpr index_t kSecondDim = kCol;
|
||||
};
|
||||
|
||||
// supports 2D transpose which will store to lds, then use ds_read_b*_tr_b* instruction to get the
|
||||
// transposed data; Layout in TransposePipelineProblem is the original layout of the data in the
|
||||
// global memory
|
||||
template <typename DataType_,
|
||||
typename Layout_,
|
||||
index_t kBlockSize_,
|
||||
index_t kRowWarps_, // how many warps in row direction
|
||||
index_t kColWarps_, // how many warps in col direction
|
||||
index_t kRowPerBlock_, // row number per block
|
||||
index_t kColPerBlock_, // col number per block
|
||||
index_t kRowPerXdl_, // row number per xdl ops
|
||||
index_t kColPerXdl_> // col number per xdl ops
|
||||
struct TransposePipelineProblem
|
||||
{
|
||||
static_assert(kRowWarps_ * kColWarps_ * get_warp_size() == kBlockSize_,
|
||||
"the block size is not correct!");
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
using Layout = remove_cvref_t<Layout_>;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kLeadNumWarps =
|
||||
TransposeTraits<Layout, kRowWarps_, kColWarps_>::kLeadDim;
|
||||
static constexpr index_t kSecondNumWarps =
|
||||
TransposeTraits<Layout, kRowWarps_, kColWarps_>::kSecondDim;
|
||||
static constexpr index_t kLeadSizePerBlock =
|
||||
TransposeTraits<Layout, kRowPerBlock_, kColPerBlock_>::kLeadDim;
|
||||
static constexpr index_t kSecondSizePerBlock =
|
||||
TransposeTraits<Layout, kRowPerBlock_, kColPerBlock_>::kSecondDim;
|
||||
static constexpr index_t kLeadSizePerXdl =
|
||||
TransposeTraits<Layout, kRowPerXdl_, kColPerXdl_>::kLeadDim;
|
||||
static constexpr index_t kSecondSizePerXdl =
|
||||
TransposeTraits<Layout, kRowPerXdl_, kColPerXdl_>::kSecondDim;
|
||||
|
||||
static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits<DataType>::kleadDim;
|
||||
static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits<DataType>::ksecondDim;
|
||||
|
||||
static_assert(kLeadSizePerBlock % kLeadNumWarps == 0,
|
||||
"block dim should be divided by warp dim!");
|
||||
static_assert(kSecondSizePerBlock % kSecondNumWarps == 0,
|
||||
"block dim should be divided by warp dim!");
|
||||
// how many rows/cols implemented in one warp
|
||||
static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps;
|
||||
static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps;
|
||||
|
||||
static_assert(kLeadSizePerWarp % kLeadSizePerXdl == 0,
|
||||
"warp dim should be divided by xdl dim!");
|
||||
static_assert(kSecondSizePerWarp % kSecondSizePerXdl == 0,
|
||||
"warp dim should be divided by xdl dim!");
|
||||
|
||||
// warp rows/cols is divided into xdl.
|
||||
static constexpr index_t kLeadXdlNumPerWarp = kLeadSizePerWarp / kLeadSizePerXdl;
|
||||
static constexpr index_t kSecondXdlNumPerWarp = kSecondSizePerWarp / kSecondSizePerXdl;
|
||||
|
||||
static_assert(kLeadSizePerXdl % kQuadrantLeadDim == 0,
|
||||
"xdl dim should be divided by quad dim!");
|
||||
static_assert(kSecondSizePerXdl % kQuadrantSecondDim == 0,
|
||||
"xdl dim should be divided by quad dim!");
|
||||
// xdl rows/cols is divided into quadrants.
|
||||
static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerXdl / kQuadrantLeadDim;
|
||||
static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerXdl / kQuadrantSecondDim;
|
||||
|
||||
static constexpr index_t kIterationsInSecondDim =
|
||||
kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size();
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = TransposePolicy>
|
||||
struct BlockTranspose
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using DataType = remove_cvref_t<typename Problem::DataType>;
|
||||
using Layout = remove_cvref_t<typename Problem::Layout>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock;
|
||||
static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock;
|
||||
|
||||
static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize<Problem>(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename InputTileWindow, typename OutputTileWindow>
|
||||
CK_TILE_DEVICE void operator()(const InputTileWindow& input_window,
|
||||
OutputTileWindow& output_window,
|
||||
void* __restrict__ p_smem)
|
||||
{
|
||||
auto input_tile_window =
|
||||
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
|
||||
auto output_tile_window =
|
||||
make_tile_window(output_window, Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
DataType* p_lds_ptr = static_cast<DataType*>(p_smem);
|
||||
constexpr auto in_lds_block_desc = Policy::template MakeLdsStoreBlockDescriptor<Problem>();
|
||||
auto input_lds_block =
|
||||
make_tensor_view<address_space_enum::lds>(p_lds_ptr, in_lds_block_desc);
|
||||
|
||||
constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor<Problem>();
|
||||
auto output_lds_block =
|
||||
make_tensor_view<address_space_enum::lds>(p_lds_ptr, out_lds_block_desc);
|
||||
|
||||
auto copy_to_lds_window =
|
||||
make_tile_window(input_lds_block,
|
||||
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
|
||||
{0, 0});
|
||||
auto load_from_lds_window =
|
||||
make_tile_window(output_lds_block,
|
||||
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeLdsLoadTileDistribution<Problem>());
|
||||
|
||||
auto x = load_tile(input_tile_window);
|
||||
|
||||
store_tile(copy_to_lds_window, x);
|
||||
block_sync_lds();
|
||||
|
||||
auto y = load_tile_transpose(load_from_lds_window);
|
||||
|
||||
store_tile(output_tile_window, y);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
59
example/ck_tile/37_transpose/transpose_api.cpp
Normal file
59
example/ck_tile/37_transpose/transpose_api.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "transpose_example.hpp"
|
||||
#include <iostream>
|
||||
|
||||
template <typename ts_type,
|
||||
ck_tile::index_t block_x,
|
||||
ck_tile::index_t block_y,
|
||||
ck_tile::index_t warp_x,
|
||||
ck_tile::index_t warp_y>
|
||||
float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s)
|
||||
{
|
||||
uint32_t dim_block_h = (a.height + block_y - 1) / block_y;
|
||||
uint32_t dim_block_w = (a.width + block_x - 1) / block_x;
|
||||
uint32_t dim_stride = a.height * a.width;
|
||||
|
||||
a.dim_stride = dim_stride;
|
||||
a.dim_block_h = dim_block_h;
|
||||
a.dim_block_w = dim_block_w;
|
||||
|
||||
using ts_problem = ck_tile::TransposePipelineProblem<ts_type,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
64,
|
||||
1,
|
||||
1,
|
||||
block_y,
|
||||
block_x,
|
||||
warp_y,
|
||||
warp_x>;
|
||||
using ts_pipeline = ck_tile::BlockTranspose<ts_problem>;
|
||||
|
||||
using kernel = ck_tile::BatchedTransposeKernel<ts_pipeline>;
|
||||
|
||||
auto kargs = kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = kernel::GridSize(a);
|
||||
constexpr dim3 blocks = kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float batched_transpose(batched_transpose_trait t,
|
||||
batched_transpose_kargs a,
|
||||
ck_tile::stream_config s)
|
||||
{
|
||||
if(t.type == "fp16")
|
||||
{
|
||||
return batched_transpose_dispatch<ck_tile::fp16_t, 16, 32, 16, 32>(a, s);
|
||||
}
|
||||
else if(t.type == "fp8")
|
||||
{
|
||||
return batched_transpose_dispatch<ck_tile::fp8_t, 16, 64, 16, 64>(a, s);
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
257
example/ck_tile/37_transpose/transpose_example.cpp
Normal file
257
example/ck_tile/37_transpose/transpose_example.cpp
Normal file
@@ -0,0 +1,257 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <time.h>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "transpose_example.hpp"
|
||||
|
||||
#if 0
|
||||
template <typename T>
|
||||
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
|
||||
{
|
||||
auto len = x.get_lengths();
|
||||
assert(len.size() == 4);
|
||||
std::cout << "[";
|
||||
for(size_t i = 0; i < len[0]; i++)
|
||||
{
|
||||
std::cout << i << ": [";
|
||||
for(size_t j = 0; j < len[1]; j++)
|
||||
{
|
||||
std::cout << j << ": [";
|
||||
for(size_t k = 0; k < len[2]; k++)
|
||||
{
|
||||
std::cout << k << ": [";
|
||||
for(size_t v = 0; v < len[3]; v++)
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
|
||||
{
|
||||
auto m =
|
||||
ck_tile::type_convert<float>(x(std::vector<std::size_t>{i, j, k, v}));
|
||||
|
||||
std::cout << m;
|
||||
if(v != len[3] - 1)
|
||||
std::cout << ",";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << x(std::vector<std::size_t>{i, j, k, v}) << " ";
|
||||
}
|
||||
}
|
||||
std::cout << "]" << std::endl;
|
||||
}
|
||||
std::cout << "]" << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << "--------------------" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
{
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
unsigned max_rounding_point_distance = 0;
|
||||
double atol = 2e-3;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
unsigned max_rounding_point_distance = 1;
|
||||
double atol = 0.0625;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "whether do CPU validation or not")
|
||||
.insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
|
||||
.insert("N", "2", "input batch size. ")
|
||||
.insert("C", "64", "input channel size.")
|
||||
.insert("H", "1", "input height size.")
|
||||
.insert("W", "64", "input width size. ")
|
||||
.insert("layout_in", "NCHW", "input tensor data layout - NCHW by default")
|
||||
.insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "t to 1 will print kernel name");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename Type>
|
||||
bool run_batched_transpose(ck_tile::ArgParser args)
|
||||
{
|
||||
int validate = args.get_int("v");
|
||||
std::string prec = args.get_str("pr");
|
||||
int N = args.get_int("N");
|
||||
int C = args.get_int("C");
|
||||
int H = args.get_int("H");
|
||||
int W = args.get_int("W");
|
||||
std::string layout_in = args.get_str("layout_in");
|
||||
std::string layout_out = args.get_str("layout_out");
|
||||
int seed = args.get_int("seed");
|
||||
|
||||
int dim_in[4], dim_out[4];
|
||||
int stride_dim_in[4], stride_dim_out[4];
|
||||
bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC";
|
||||
bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW";
|
||||
assert(nchw2nhwc != nhwc2nchw);
|
||||
(void)nhwc2nchw;
|
||||
|
||||
dim_in[0] = N;
|
||||
dim_in[1] = nchw2nhwc ? C : H;
|
||||
dim_in[2] = nchw2nhwc ? H : W;
|
||||
dim_in[3] = nchw2nhwc ? W : C;
|
||||
dim_out[0] = N;
|
||||
dim_out[1] = nchw2nhwc ? H : C;
|
||||
dim_out[2] = nchw2nhwc ? W : H;
|
||||
dim_out[3] = nchw2nhwc ? C : W;
|
||||
stride_dim_in[0] = C * H * W;
|
||||
stride_dim_in[1] = nchw2nhwc ? H * W : C * W;
|
||||
stride_dim_in[2] = nchw2nhwc ? W : C;
|
||||
stride_dim_in[3] = 1;
|
||||
stride_dim_out[0] = C * H * W;
|
||||
stride_dim_out[1] = nchw2nhwc ? C * W : H * W;
|
||||
stride_dim_out[2] = nchw2nhwc ? C : W;
|
||||
stride_dim_out[3] = 1;
|
||||
|
||||
if(seed < 0)
|
||||
{
|
||||
seed = std::time(nullptr);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<Type> x_host(
|
||||
{dim_in[0], dim_in[1], dim_in[2], dim_in[3]},
|
||||
{stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]});
|
||||
ck_tile::HostTensor<Type> y_host(
|
||||
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
|
||||
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
|
||||
|
||||
ck_tile::FillUniformDistribution<Type>{-.5f, .5f}(x_host);
|
||||
|
||||
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_dev.ToDevice(x_host.data());
|
||||
|
||||
auto trait = batched_transpose_trait{prec, layout_in};
|
||||
|
||||
uint32_t height = nchw2nhwc ? C : H * W;
|
||||
uint32_t width = nchw2nhwc ? H * W : C;
|
||||
|
||||
batched_transpose_kargs karg = [&]() {
|
||||
batched_transpose_kargs a_;
|
||||
a_.p_input = x_dev.GetDeviceBuffer();
|
||||
a_.p_output = y_dev.GetDeviceBuffer();
|
||||
a_.batch = N;
|
||||
a_.height = height;
|
||||
a_.width = width;
|
||||
return a_;
|
||||
}();
|
||||
|
||||
ck_tile::stream_config sc{nullptr, true};
|
||||
|
||||
auto ms = batched_transpose(trait, karg, sc);
|
||||
|
||||
std::size_t num_operations = N * C * H * (W - 1);
|
||||
std::size_t num_bytes = N * C * H * W * sizeof(Type);
|
||||
|
||||
float ave_time = ms * 1E-3;
|
||||
float gb_per_sec = num_bytes / ms * 1.E-6;
|
||||
float tflops = static_cast<float>(num_operations) / ms * 1.E-6;
|
||||
|
||||
std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H
|
||||
<< ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out
|
||||
<< " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops"
|
||||
<< gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n",
|
||||
prec.c_str(),
|
||||
N,
|
||||
C,
|
||||
H,
|
||||
W,
|
||||
layout_in.c_str(),
|
||||
ms);
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
fflush(stdout);
|
||||
|
||||
if(ms < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
y_dev.FromDevice(y_host.data());
|
||||
|
||||
bool rtn = true;
|
||||
if(validate)
|
||||
{
|
||||
// this host buffer will not copy to GPU, so no need use stride
|
||||
ck_tile::HostTensor<Type> y_ref(
|
||||
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
|
||||
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
|
||||
|
||||
ck_tile::reference_batched_transpose<Type>(x_host, y_ref, layout_in, layout_out);
|
||||
|
||||
auto [rtol, atol] = get_elimit<Type>("");
|
||||
|
||||
rtn &= ck_tile::check_err(
|
||||
y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol);
|
||||
}
|
||||
printf("valid:%s\n", rtn ? "y" : "n");
|
||||
fflush(stdout);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
std::string prec = args.get_str("pr");
|
||||
|
||||
bool r = true;
|
||||
if(prec.compare("fp16") == 0)
|
||||
{
|
||||
r &= run_batched_transpose<ck_tile::fp16_t>(args);
|
||||
}
|
||||
else if(prec.compare("fp8") == 0)
|
||||
{
|
||||
r &= run_batched_transpose<ck_tile::fp8_t>(args);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type: " << prec << std::endl;
|
||||
}
|
||||
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
27
example/ck_tile/37_transpose/transpose_example.hpp
Normal file
27
example/ck_tile/37_transpose/transpose_example.hpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "batched_transpose_kernel.hpp"
|
||||
#include "block_transpose.hpp"
|
||||
#include "transpose_policy.hpp"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#pragma once
|
||||
|
||||
struct batched_transpose_trait
|
||||
{
|
||||
std::string type;
|
||||
std::string layout;
|
||||
};
|
||||
|
||||
struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
float batched_transpose(batched_transpose_trait t,
|
||||
batched_transpose_kargs a,
|
||||
ck_tile::stream_config s);
|
||||
151
example/ck_tile/37_transpose/transpose_policy.hpp
Normal file
151
example/ck_tile/37_transpose/transpose_policy.hpp
Normal file
@@ -0,0 +1,151 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct TransposePolicy
|
||||
{
|
||||
static constexpr auto TileAccessPattern = tile_distribution_pattern::thread_raked;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSize()
|
||||
{
|
||||
return 16 / sizeof(typename Problem::DataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return integer_least_multiple(
|
||||
sizeof(typename Problem::DataType) *
|
||||
MakeLdsStoreBlockDescriptor<Problem>().get_element_space_size(),
|
||||
16);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
constexpr index_t VecLoadSize = 16 / sizeof(typename Problem::DataType);
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
SecondDimPerBlock,
|
||||
LeadDimPerBlock,
|
||||
VecLoadSize,
|
||||
TileAccessPattern>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
{
|
||||
constexpr auto input_dstr = MakeLdsLoadTileDistribution<Problem>();
|
||||
|
||||
using OutTileDstrEncode =
|
||||
typename OutputTileDistributionTraits<remove_cvref_t<decltype(input_dstr)>,
|
||||
typename Problem::DataType>::OutDstrEncode;
|
||||
constexpr auto block_dstr = make_static_tile_distribution(OutTileDstrEncode{});
|
||||
|
||||
return block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kSecondDimPerBlock>{},
|
||||
number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}),
|
||||
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
|
||||
number<kVectorSize>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_block_desc = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
|
||||
constexpr index_t kVectorSize = 8 / sizeof(typename Problem::DataType);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kSecondDimPerBlock>{},
|
||||
number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}),
|
||||
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
|
||||
number<kVectorSize>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_block_desc = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadTileDistribution()
|
||||
{
|
||||
using DataType = typename Problem::DataType;
|
||||
|
||||
// Extract base dimensions from the traits
|
||||
constexpr index_t kBaseLeadDim = LaneGroupTransposeTraits<DataType>::kleadDim;
|
||||
constexpr index_t kBaseSecondDim = LaneGroupTransposeTraits<DataType>::ksecondDim;
|
||||
|
||||
// Calculate block-level dimensions
|
||||
constexpr index_t kLead = Problem::kLeadSizePerXdl;
|
||||
constexpr index_t kSecond = Problem::kSecondSizePerXdl;
|
||||
constexpr index_t kLeadIterPerWarp = Problem::kLeadXdlNumPerWarp;
|
||||
constexpr index_t kSecondIterPerWarp = Problem::kSecondXdlNumPerWarp;
|
||||
constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps;
|
||||
constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps;
|
||||
|
||||
// Calculate repetitions of base pattern
|
||||
constexpr index_t kLeadRepetitions = kLead / kBaseLeadDim;
|
||||
constexpr index_t kSecondRepetitions = kSecond / kBaseSecondDim;
|
||||
constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim;
|
||||
constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations;
|
||||
|
||||
constexpr auto xdllevel_dstr_encoding = make_transposed_distr_encode<DataType,
|
||||
kSecondDimStrSub,
|
||||
kSecondDimIterations,
|
||||
kLeadRepetitions,
|
||||
1>();
|
||||
|
||||
constexpr auto input_tile_encode =
|
||||
InputTileDistributionEncoding<decltype(xdllevel_dstr_encoding),
|
||||
kLeadIterPerWarp,
|
||||
kSecondIterPerWarp,
|
||||
kLeadNumWarps,
|
||||
kSecondNumWarps>();
|
||||
constexpr auto block_dstr = make_static_tile_distribution(input_tile_encode);
|
||||
return block_dstr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -18,5 +18,8 @@ add_subdirectory(15_fused_moe)
|
||||
add_subdirectory(16_batched_gemm)
|
||||
add_subdirectory(17_grouped_gemm)
|
||||
add_subdirectory(18_flatmm)
|
||||
add_subdirectory(19_gemm_multi_d)
|
||||
add_subdirectory(20_grouped_convolution)
|
||||
add_subdirectory(35_batched_transpose)
|
||||
add_subdirectory(36_copy)
|
||||
add_subdirectory(37_transpose)
|
||||
|
||||
Reference in New Issue
Block a user