Merge remote-tracking branch 'origin/vpietila/ckb-bwd-weight-factories' into vpietila/ckb-refactor-warp-gemm-descriptors

This commit is contained in:
Ville Pietilä
2026-01-07 06:16:43 -05:00
164 changed files with 11855 additions and 6728 deletions

View File

@@ -54,7 +54,7 @@ jobs:
with:
repository: "ROCm/TheRock"
path: "TheRock"
ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit
ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit
- name: Setup ccache
run: |
@@ -78,8 +78,9 @@ jobs:
run: |
git config --global --add safe.directory '*'
# Remove patches here if they cannot be applied cleanly, and they have not been deleted from TheRock repo
rm -f ./TheRock/patches/amd-mainline/rocm-libraries/0008-Revert-remove-options-no-enumerate-966.patch
git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch
rm ./TheRock/patches/amd-mainline/rocm-libraries/0003-Find-rocm_smi-via-config-files.patch
rm ./TheRock/patches/amd-mainline/rocm-libraries/0007-Remove-Windows-third_party_dlls-copying-code.patch
# git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch
- name: Install python deps
run: |

View File

@@ -51,7 +51,7 @@ jobs:
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
repository: "ROCm/TheRock"
ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit
ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit
- name: Run setup test environment workflow
uses: './.github/actions/setup_test_environment'

View File

@@ -27,7 +27,7 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: "ROCm/TheRock"
ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit
ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit
- name: "Configuring CI options"
env:

View File

@@ -5,12 +5,15 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
## (Unreleased) Composable Kernel 1.3.0
### Added
* Added preshuffleB support for abquant mode in blockscale GEMM.
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
* Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines.
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
* Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel.
* Added FP8 KV cache support for FMHA batch prefill.
* Added support for gfx1153 target.
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
### Changed

View File

@@ -2,7 +2,7 @@ ARG BASE_DOCKER="rocm/pytorch:latest"
FROM $BASE_DOCKER
ARG AITER_BRANCH="main"
ARG CK_AITER_BRANCH="develop"
RUN pip install pandas zmq einops ninja && \
RUN pip install pandas zmq einops ninja tabulate && \
pip install numpy==1.26.2 && \
sudo mkdir /home/jenkins && \
sudo mkdir /home/jenkins/workspace && \

6
Jenkinsfile vendored
View File

@@ -1046,7 +1046,7 @@ def run_aiter_tests(Map conf=[:]){
sh "rocminfo"
sh "python3 --version"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py"
//sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" //temporarily disable
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py"
@@ -1469,8 +1469,8 @@ pipeline {
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
make -j64 test_grouped_convnd_fwd_large_cases test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \
./bin/test_grouped_convnd_fwd_large_cases && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases"""
make -j64 test_grouped_convnd_fwd_large_cases test_grouped_convnd_bwd_data_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \
./bin/test_grouped_convnd_fwd_large_cases && ./bin/test_grouped_convnd_bwd_data_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases"""
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)

View File

@@ -36,6 +36,19 @@ DTYPE_BITS = {
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
SUPPORTED_PAGE_SIZE = [128, 256, 1024]
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
KV_MEMORY_LAYOUT_ENUM_MAP = {
"vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
"linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
}
KV_LOOKUP_TABLE_ENUM_MAP = {
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
}
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
}
@@ -59,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
@@ -69,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_lse},
{F_dropout},
{F_qscale},
{F_occupancy}>;
{F_occupancy},
false,
{F_page_size},
{F_kv_memory_layout},
{F_kv_lookup_table}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
@@ -92,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
fmha_variant_{F_idx},
fmha_mask_{F_idx},
false,
{F_page_size},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
@@ -105,8 +123,8 @@ using fmha_epilogue_{F_idx} =
using fmha_kernel_{F_idx} =
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
#include <iostream>
@@ -184,8 +202,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@@ -230,12 +248,15 @@ class FmhaFwdApiTrait:
dpad: str
dvpad: str
constraint: CppConstraint
kv_memory_layout: str
kv_lookup_table: str
page_size: int = 1 # page block size
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
)
@property
@@ -322,6 +343,8 @@ class FmhaFwdPipeline:
F_dropout: str #
F_qscale: str # no/pertensor
F_mask: str # value from MASK_MAP
F_kv_memory_layout: str #
F_kv_lookup_table: str #
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@property
@@ -382,6 +405,8 @@ class FmhaFwdPipeline:
n += f"_{self.F_qscale}"
else:
n += "_nqscale"
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
return n
@@ -440,6 +465,13 @@ class FmhaFwdApiPool:
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
trait.kv_memory_layout
],
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
trait.kv_lookup_table
],
F_page_size=trait.page_size,
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
@@ -497,6 +529,7 @@ class FmhaFwdKernel:
F_tile: FmhaFwdTileSize
F_pipeline: FmhaFwdPipeline
mask_impl: str
F_page_size: int = 1 # page block size
@property
def template(self) -> str:
@@ -534,17 +567,24 @@ class FmhaFwdKernel:
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
F_occupancy=self.F_tile.F_occupancy,
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
self.F_pipeline.F_kv_memory_layout
],
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
self.F_pipeline.F_kv_lookup_table
],
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
F_page_size=self.F_page_size,
)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return (
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
@@ -578,6 +618,9 @@ class FmhaFwdKernel:
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
page_size=self.F_page_size,
)
@@ -604,23 +647,42 @@ class KernelComponentFactory:
pipelines = []
if dtype in ["fp16", "bf16"]:
qscale = "no"
for logits, mask, bias, lse, dropout in itertools.product(
for (
logits,
mask,
bias,
lse,
dropout,
kv_memory_layout,
kv_lookup_table,
) in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
SUPPORTED_KV_MEMORY_LAYOUT,
SUPPORTED_KV_LOOKUP_TABLE,
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
elif dtype in ["fp8bf16"]:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
for (
logits,
qscale,
mask,
bias,
kv_memory_layout,
kv_lookup_table,
) in itertools.product(
["t", "f"],
["pertensor"],
get_mask_map(mask_impl).keys(),
["no"],
SUPPORTED_KV_MEMORY_LAYOUT,
SUPPORTED_KV_LOOKUP_TABLE,
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
else:
assert False
return pipelines
@@ -672,69 +734,73 @@ def get_fwd_blobs(
or pipeline.F_logits == "f"
):
continue
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
# Generate kernels for both page_size=16 and page_size=1024
for page_size in SUPPORTED_PAGE_SIZE:
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
F_page_size=page_size,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)

View File

@@ -529,14 +529,25 @@ struct fmha_batch_prefill_args
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
// SGLang-style page table
int32_t num_total_pages;
void* kv_indptr;
void* kv_page_indices;
#if 0 // we assume page_block_size=1 for now
void* kv_last_page_lens;
ck_tile::index_t page_block_size;
#endif
// KV cache page table fields (kv_lookup_table selects interpretation):
// - SGLANG_PAGE_TABLE_1D:
// kv_indptr: prefix-sum [batch+1] into kv_page_indices
// kv_page_indices: 1D list of physical page ids, length = num_total_pages
// kv_last_page_lens: per-batch last page lengths [batch]
// - VLLM_BLOCK_TABLE_2D:
// kv_page_indices: block_table [batch, max_blocks_per_seq] (2D)
// batch_stride_block_table: row stride for block_table
// seqlen_k_ptr: per-batch seqlen_k [batch]
int32_t num_total_pages; // total physical pages in KV cache (SGLang/vLLM)
ck_tile::index_t page_block_size; // tokens per page (SGLang/vLLM)
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum
kv_memory_layout; // KV memory layout (SGLang/vLLM)
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; // lookup table layout selector
void* kv_indptr; // SGLang: prefix-sum; vLLM: unused
void* kv_page_indices; // SGLang: 1D page list; vLLM: block_table 2D
void* kv_last_page_lens; // SGLang: last page lengths; vLLM: unused
void* seqlen_k_ptr; // vLLM: per-batch seqlen_k; SGLang: unused
ck_tile::index_t batch_stride_block_table; // vLLM: row stride; SGLang: unused
float scale_s;
float scale_p;
@@ -1113,6 +1124,22 @@ template <typename FmhaKernel>
auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
using PageTableKargs = typename FmhaKernel::PageBlockTableKargs;
const PageTableKargs page_table = [&]() {
if constexpr(FmhaKernel::kKVLookupTable ==
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_indptr),
reinterpret_cast<const int32_t*>(args.kv_page_indices),
reinterpret_cast<const int32_t*>(args.kv_last_page_lens)};
}
else
{
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_page_indices),
args.batch_stride_block_table,
reinterpret_cast<const int32_t*>(args.seqlen_k_ptr)};
}
}();
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
@@ -1133,12 +1160,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_total_pages,
args.kv_indptr,
args.kv_page_indices,
#if 0 // we assume page_block_size=1 for now
args.kv_last_page_lens,
args.page_block_size,
#endif
page_table,
args.scale_s,
args.scale_p,
args.scale_o,
@@ -1184,12 +1207,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_total_pages,
args.kv_indptr,
args.kv_page_indices,
#if 0 // we assume page_block_size=1 for now
args.kv_last_page_lens,
args.page_block_size,
#endif
page_table,
args.scale_s,
args.scale_p,
args.scale_o,
@@ -1281,6 +1300,65 @@ struct fmha_fwd_traits_
static constexpr bool kHasSink = kHasSink_;
};
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
bool kHasLogitsSoftCap_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kHasDropout_,
ck_tile::BlockAttentionQuantScaleEnum QScaleEnum_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kUseTrLoad_,
bool kSkipMinSeqlenQ_ = false,
ck_tile::index_t kPageBlockSize_ = 1,
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
DataType_,
kIsGroupMode_,
kM0_,
kN0_,
kK0_,
kN1_,
kK1_,
kK0BlockLength_,
kIsVLayoutRowMajor_,
FmhaPipelineEnum_,
kHasLogitsSoftCap_,
FmhaMask_,
BiasEnum_,
kStoreLse_,
kHasDropout_,
QScaleEnum_,
kPadS_,
kPadSK_,
kPadD_,
kPadDv_,
kUseTrLoad_,
kSkipMinSeqlenQ_,
false>
{
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
};
template <typename Traits_, typename Arch = void>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
@@ -1527,7 +1605,15 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args,
const ck_tile::stream_config&);
using fmha_batch_prefill_traits = fmha_fwd_traits;
struct fmha_batch_prefill_traits : public fmha_fwd_traits
{
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table =
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
int page_size = 1;
};
float fmha_batch_prefill(fmha_batch_prefill_traits,
fmha_batch_prefill_args,
const ck_tile::stream_config&);

View File

@@ -69,107 +69,88 @@ struct BasicInvoker
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
// Declare rotating_mem_ptr here so it stays in scope until it is needed
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
std::function<void()> preprocess;
// Declare rotating_mem_ptr here so it stays in scope until it is needed
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
std::function<void()> preprocess;
auto clear_gemm_output = [&]() {
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
rotating_mem_ptr =
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
kargs.as_ptr[0],
kargs.bs_ptr[0],
s.rotating_count_,
size_a_buffer,
size_b_buffer);
rotating_mem_ptr->Print();
preprocess = [&]() {
ck_tile::flush_icache();
rotating_mem_ptr->Next();
clear_gemm_output();
};
}
else
{
preprocess = clear_gemm_output;
}
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
auto clear_gemm_output = [&]() {
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
if(args.k_batch == 1)
if(s.flush_cache_)
{
return Run(MemoryOpSet{});
std::cout << "Flushing cache..." << std::endl;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem_ptr->Print();
preprocess = [&]() {
ck_tile::flush_icache();
rotating_mem_ptr->Next();
clear_gemm_output();
};
}
else
{
return Run(MemoryOpAtomicAdd{});
preprocess = clear_gemm_output;
}
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
};

View File

@@ -72,160 +72,144 @@ struct SplitKTwoStageInvoker
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
WorkspaceType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
GemmConfig::NumWaveGroups>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
WorkspaceType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType));
ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args);
auto c_ptr = ws_args.c_ptr;
ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args);
ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType));
ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args);
auto c_ptr = ws_args.c_ptr;
ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args);
const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s)
: GemmKernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = GemmKernel::BlockSize();
const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s)
: GemmKernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = GemmKernel::BlockSize();
if(!GemmKernel::IsSupportedArgument(gemm_kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(!GemmKernel::IsSupportedArgument(gemm_kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
using BlockTile = ck_tile::sequence<2048>;
using BlockWarps = ck_tile::sequence<8>;
using WarpTile = ck_tile::sequence<64>;
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
using BlockTile = ck_tile::sequence<2048>;
using BlockWarps = ck_tile::sequence<8>;
using WarpTile = ck_tile::sequence<64>;
using ElementwiseShape =
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceType>;
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceType,
WorkspaceType,
CDataType,
ElementwiseShape,
XElementwiseOperation>;
using ElementwiseKernel =
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
using ElementwiseShape =
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceType>;
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceType,
WorkspaceType,
CDataType,
ElementwiseShape,
XElementwiseOperation>;
using ElementwiseKernel =
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
ck_tile::index_t total_elements = 1;
std::vector<ck_tile::index_t> shape = {args.M, args.N};
ck_tile::index_t total_elements = 1;
std::vector<ck_tile::index_t> shape = {args.M, args.N};
for(auto d : shape)
total_elements *= d;
for(auto d : shape)
total_elements *= d;
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
ck_tile::index_t kGridSize =
(total_elements + elements_per_block - 1) / elements_per_block;
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceType*>(ws_args.c_ptr));
auto input_size = ck_tile::make_tuple(args.M, args.N);
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceType*>(ws_args.c_ptr));
auto input_size = ck_tile::make_tuple(args.M, args.N);
// Check if the kernel configuration is supported
if(!ElementwiseKernel::IsSupportedArgument(input_size))
{
throw std::runtime_error(
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
}
// Check if the kernel configuration is supported
if(!ElementwiseKernel::IsSupportedArgument(input_size))
{
throw std::runtime_error(
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
// Declare rotating_mem_ptr here so it stays in scope until it is needed
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
std::function<void()> preprocess;
// Declare rotating_mem_ptr here so it stays in scope until it is needed
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
std::function<void()> preprocess;
auto clear_gemm_output = [&]() {
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_));
};
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
rotating_mem_ptr =
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
gemm_kargs.as_ptr[0],
gemm_kargs.bs_ptr[0],
s.rotating_count_,
size_a_buffer,
size_b_buffer);
rotating_mem_ptr->Print();
preprocess = [&]() {
ck_tile::flush_icache();
rotating_mem_ptr->Next();
clear_gemm_output();
};
}
else
{
preprocess = clear_gemm_output;
}
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
GemmKernel{}, grids, blocks, 0, gemm_kargs),
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(args.N, 1), // Input Stride
ck_tile::make_tuple(args.N, 1), // Output Stride
input_tensors,
static_cast<CDataType*>(c_ptr)));
auto clear_gemm_output = [&]() {
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_));
};
if(args.k_batch == 1)
if(s.flush_cache_)
{
return Run(MemoryOpSet{});
std::cout << "Flushing cache..." << std::endl;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
gemm_kargs.as_ptr[0],
gemm_kargs.bs_ptr[0],
s.rotating_count_,
size_a_buffer,
size_b_buffer);
rotating_mem_ptr->Print();
preprocess = [&]() {
ck_tile::flush_icache();
rotating_mem_ptr->Next();
clear_gemm_output();
};
}
else
{
return Run(MemoryOpAtomicAdd{});
preprocess = clear_gemm_output;
}
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
GemmKernel{}, grids, blocks, 0, gemm_kargs),
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(args.N, 1), // Input Stride
ck_tile::make_tuple(args.N, 1), // Output Stride
input_tensors,
static_cast<CDataType*>(c_ptr)));
}
};

View File

@@ -160,110 +160,101 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
args.stride_E);
constexpr auto scheduler = GemmConfig::Scheduler;
const auto Run = [&]() {
// use SET operation since each K-split writes to separate memory
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue =
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
GemmConfig::NumWaveGroups>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(base_args);
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(base_args);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
const dim3 blocks = Kernel::BlockSize();
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
return ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
};
return Run();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
return ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
}
/**

View File

@@ -460,12 +460,6 @@ inline auto create_args()
return arg_parser;
}
// Type aliases for memory operation integral constants
using MemoryOpSet =
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>;
// host API
template <typename ADataType,
typename BDataType,

View File

@@ -57,114 +57,95 @@ struct WeightPreshuffleInvoker
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups,
false,
1,
GemmConfig::TiledMMAPermuteN>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
GemmConfig::NumWaveGroups,
false,
1,
GemmConfig::TiledMMAPermuteN>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}"
<< std::endl;
}
float ave_time = 0.f;
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(kargs.as_ptr[0],
kargs.bs_ptr[0],
s.rotating_count_,
size_a_buffer,
size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time =
ck_tile::launch_kernel_time_mask(s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
if(args.k_batch == 1)
dim3 grids;
if constexpr(Persistent)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
throw std::runtime_error("split-k is not supported yet!");
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl;
}
float ave_time = 0.f;
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
}
};

View File

@@ -60,112 +60,94 @@ struct UniversalInvoker
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
GemmConfig::NumWaveGroups,
false, /*FixedVectorSize_*/
1, /*VectorSizeC_*/
false, /*TiledMMAPermuteN_*/
1, /*BlockedXDLN_PerWarp_*/
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups,
false, /*FixedVectorSize_*/
1, /*VectorSizeC_*/
false, /*TiledMMAPermuteN_*/
1, /*BlockedXDLN_PerWarp_*/
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s)
: Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s)
: Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
// Declare rotating_mem_ptr here so it stays in scope until it is needed
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
std::function<void()> preprocess;
// Declare rotating_mem_ptr here so it stays in scope until it is needed
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
std::function<void()> preprocess;
auto clear_gemm_output = [&]() {
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
rotating_mem_ptr =
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
kargs.as_ptr[0],
kargs.bs_ptr[0],
s.rotating_count_,
size_a_buffer,
size_b_buffer);
rotating_mem_ptr->Print();
preprocess = [&]() {
ck_tile::flush_icache();
rotating_mem_ptr->Next();
clear_gemm_output();
};
}
else
{
preprocess = clear_gemm_output;
}
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
auto clear_gemm_output = [&]() {
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
if(args.k_batch == 1)
if(s.flush_cache_)
{
return Run(MemoryOpSet{});
std::cout << "Flushing cache..." << std::endl;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem_ptr->Print();
preprocess = [&]() {
ck_tile::flush_icache();
rotating_mem_ptr->Next();
clear_gemm_output();
};
}
else
{
return Run(MemoryOpAtomicAdd{});
preprocess = clear_gemm_output;
}
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
};

View File

@@ -334,13 +334,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
if(moe_buf_bytes > 0)
{
#if MOE_SORTING_FMOE_2D_BUF
printf("moe_buf:%lu(%d,%d), ",
printf("moe_buf:%" PRIu64 "(%d,%d), ",
static_cast<uint64_t>(moe_buf_bytes),
moe_buf_interm_dim,
moe_buf_elem_bytes);
#else
printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
printf("moe_buf:%" PRIu64 ", ", static_cast<uint64_t>(moe_buf_bytes));
#endif
}

View File

@@ -78,63 +78,48 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
};
if(args.k_batch == 1)
if(!Kernel::IsSupportedArgument(kargs))
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
else
if(s.log_level_ > 0)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
#include "run_batched_gemm_example.inc"

View File

@@ -14,7 +14,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95")
quant_grouped_gemm_bf8_rowcol.cpp
quant_grouped_gemm_bf8_tensor.cpp
)
add_executable(tile_example_abquant_grouped_gemm abquant_grouped_gemm.cpp)
add_executable(tile_example_grouped_gemm_preshuffle grouped_gemm_preshuffle.cpp)
add_executable(tile_example_grouped_gemm_multi_d grouped_gemm_multi_d.cpp)
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
@@ -25,4 +25,5 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95")
target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_abquant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -0,0 +1,278 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include <type_traits>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
#include "ck_tile/host.hpp"
#include "abquant_grouped_gemm.hpp"
// Non-persistent grouped gemm for ABQuant
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename AQuantGroupSize,
typename BQuantGroupSize,
ck_tile::QuantType QuantMode>
float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantMode,
AQLayout,
BQLayout,
GemmConfig::TransposeC,
GemmConfig::DoubleSmemBuffer,
GemmConfig::Persistent>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline =
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
GemmConfig::PreshuffleB>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
AQuantGroupSize,
BQuantGroupSize,
GemmConfig::TransposeC,
BDataType,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
QuantGemmProblem::TransposeC>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
// Persistent grouped gemm tileloop for ABQuant
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename AQuantGroupSize,
typename BQuantGroupSize,
ck_tile::QuantType QuantMode>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr)
{
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantMode,
AQLayout,
BQLayout,
GemmConfig::TransposeC,
GemmConfig::DoubleSmemBuffer,
GemmConfig::Persistent>;
using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
AQuantGroupSize,
BQuantGroupSize,
GemmConfig::TransposeC>;
using GemmPipeline = GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
QuantGemmProblem::TransposeC>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
}
#include "run_grouped_gemm_abquant_example.inc"
int main(int argc, char* argv[])
{
int result1 = run_abquant_grouped_gemm_example(argc, argv);
return result1;
}

View File

@@ -0,0 +1,171 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"
template <typename DataType>
struct GemmTypeConfig;
template <>
struct GemmTypeConfig<ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <bool Persistent_>
struct GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool PreshuffleB = false;
static constexpr bool Persistent = Persistent_;
};
template <typename PrecType, bool Persistent>
struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <ck_tile::QuantType QuantMode>
struct GemmQuantConfig;
// ABQuant specialization for GemmQuantConfig
template <>
struct GemmQuantConfig<ck_tile::QuantType::ABQuantGrouped>
{
template <typename PrecType, bool Persistent>
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
.insert("Ns", "", "N dimensions - empty by default.")
.insert("Ks", "", "K dimensions - empty by default.")
.insert(
"stride_As",
"",
"Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs
// can be set to zero if
// Ms/Ns/Ks is not empty
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
.insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
.insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "C", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
.insert("kbatch", "1", "kbatch for SplitK")
.insert("init", "0", "0. Random, 2. One(s) (Constant)")
.insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.")
.insert("bquant_group_size", "1x1x128", "BQuant group size. 1x1x128 (default) or 1x128x128")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "abquant_grouped_gemm.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
}
// Forward declaration of the non-persistent version
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename AQuantGroupSize,
typename BQuantGroupSize,
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr);
// Forward declaration of the tileloop version for persistent kernels
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename AQuantGroupSize,
typename BQuantGroupSize,
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr);

View File

@@ -62,71 +62,55 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
if(gemm_descs[0].k_batch == 1)
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
throw std::runtime_error("Kernel arguments not supported!");
}
else
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
}
template <typename GemmConfig,
@@ -139,8 +123,7 @@ template <typename GemmConfig,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
void* kargs_ptr)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
@@ -161,74 +144,55 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
BLayout,
CLayout>;
float ave_time{0};
constexpr auto scheduler = GemmConfig::Scheduler;
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
};
if(!splitk)
if(s.log_level_ > 0)
{
return ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
return ave_time =
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
}
#include "run_grouped_gemm_example.inc"

View File

@@ -328,5 +328,4 @@ template <typename GemmConfig,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk = false);
void* kargs_ptr);

View File

@@ -61,72 +61,56 @@ float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_d
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { "
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
if(gemm_descs[0].k_batch == 1)
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
throw std::runtime_error("Kernel arguments not supported!");
}
else
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { "
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
}
template <typename GemmConfig,
@@ -142,8 +126,7 @@ template <typename GemmConfig,
typename CDEElementWise>
float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
void* kargs_ptr)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
@@ -163,76 +146,55 @@ float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s,
BLayout,
ELayout>;
float ave_time{0};
constexpr auto scheduler = GemmConfig::Scheduler;
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
};
if(!splitk)
if(s.log_level_ > 0)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ave_time;
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
}
#include "run_grouped_gemm_multi_d_example.inc"

View File

@@ -65,70 +65,54 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
if(gemm_descs[0].k_batch == 1)
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
throw std::runtime_error("Kernel arguments not supported!");
}
else
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
}
template <typename GemmConfig,
@@ -141,8 +125,7 @@ template <typename GemmConfig,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
void* kargs_ptr)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
@@ -167,75 +150,53 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
float ave_time{0};
constexpr auto scheduler = GemmConfig::Scheduler;
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
};
if(splitk)
if(s.log_level_ > 0)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ave_time;
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
}
#include "run_grouped_gemm_example.inc"

View File

@@ -72,10 +72,9 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped;
@@ -137,8 +136,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
QuantGemmProblem::TransposeC,
memory_operation>>;
QuantGemmProblem::TransposeC>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
@@ -224,90 +222,79 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
GemmConfig::DoubleSmemBuffer,
GemmConfig::Persistent>;
float ave_time{0};
constexpr auto scheduler = GemmConfig::Scheduler;
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped;
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
GemmConfig::TransposeC>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
GemmConfig::TransposeC,
BDataType,
scheduler>>;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
GemmConfig::TransposeC>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
GemmConfig::TransposeC,
BDataType,
scheduler>>;
using GemmPipeline = GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
QuantGemmProblem::TransposeC>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
QuantGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
};
return ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
}

View File

@@ -0,0 +1,604 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename AQuantGroupSize,
typename BQuantGroupSize,
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_abquant_gemm(int n_warmup,
int n_repeat,
int group_count,
const std::vector<grouped_gemm_kargs>& args)
{
// Workspace memory allocated to hold the gemm descriptions.
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(get_workspace_size(args));
float ave_time = 0;
if constexpr(!GemmConfig::Persistent)
{
ave_time = grouped_gemm_abquant<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
gemm_workspace.GetDeviceBuffer());
}
else
{
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
// the gemm problems known on the host. Instead, we can just pass the pointer
// to the kernel and let the workgroups figure out which tiles to work on.
// This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
if(args[0].k_batch != 1)
{
throw std::runtime_error("Split-K not supported yet for persistent kernel");
}
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.aq_ptr,
arg.bq_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.QK_A,
arg.QK_B,
arg.stride_A,
arg.stride_B,
arg.stride_E,
arg.stride_AQ,
arg.stride_BQ,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(stream, group_count, kargs_ptr);
}
return ave_time;
}
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename CDataType,
typename AccDataType,
typename AQuantGroupSize,
typename BQuantGroupSize,
ck_tile::QuantType QuantMode,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout>
int run_abquant_grouped_gemm_example_with_layouts(
int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const AQLayout aq_layout = AQLayout{},
const BLayout b_layout = BLayout{},
const BQLayout bq_layout = BQLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
auto valid_input_data = [&](int group_count, const auto&... args) {
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
};
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
const int init_method = arg_parser.get_int("init");
bool validate = arg_parser.get_bool("validate");
if(kbatch > 1 && validate && warmup + repeat > 1)
{
std::cout << "WARNING: Data validation enabled with SplitK and more than"
<< "1 warmup/repeat. Disabling validation." << std::endl;
validate = false;
}
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> AQs; // dimension of AQ tensor is calculated from A tensor
std::vector<ck_tile::index_t> BQs; // dimension of BQ tensor is calculated from B tensor
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
std::vector<ck_tile::index_t> stride_AQs = arg_parser.get_int_vec("stride_AQs");
std::vector<ck_tile::index_t> stride_BQs = arg_parser.get_int_vec("stride_BQs");
ck_tile::index_t AQK, BQK;
if(!valid_input_data(
group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs))
{
std::cout << "Please check the input data. Default values will be used." << std::endl;
// Clear existing (invalid) data before adding defaults
Ms.clear();
Ns.clear();
Ks.clear();
stride_As.clear();
stride_Bs.clear();
stride_Cs.clear();
stride_AQs.clear();
stride_BQs.clear();
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 128 * i);
// Let get_default_stride calculate based on layout
stride_As.push_back(0);
stride_Bs.push_back(0);
stride_Cs.push_back(0);
stride_AQs.push_back(0);
stride_BQs.push_back(0);
}
}
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
std::vector<ck_tile::HostTensor<AQDataType>> aq_tensors;
std::vector<ck_tile::HostTensor<BQDataType>> bq_tensors;
a_m_k_tensors.reserve(group_count);
b_k_n_tensors.reserve(group_count);
c_m_n_tensors.reserve(group_count);
aq_tensors.reserve(group_count);
bq_tensors.reserve(group_count);
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> aq_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> bq_dev_buf;
a_m_k_dev_buf.reserve(group_count);
b_k_n_dev_buf.reserve(group_count);
c_m_n_dev_buf.reserve(group_count);
aq_dev_buf.reserve(group_count);
bq_dev_buf.reserve(group_count);
std::vector<grouped_gemm_kargs> gemm_descs;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
const ck_tile::index_t M = Ms[i];
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
// For ABQuantGrouped, both A and B need quantization
static_assert(QuantMode == ck_tile::QuantType::ABQuantGrouped,
"This file only supports ABQuantGrouped mode");
AQK = K / AQuantGroupSize::kK; // Group quantization: AQK = K / AQuantGroupSize
BQK = K / BQuantGroupSize::kK; // Group quantization: BQK = K / BQuantGroupSize
if(K % AQuantGroupSize::kK != 0)
{
throw std::runtime_error(
"K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode");
}
if(K % BQuantGroupSize::kK != 0)
{
throw std::runtime_error(
"K must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode");
}
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
stride_BQs[i] = ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout));
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
<< " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl;
if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<AQDataType>{1.f, 1.f}(aq_tensors[i]);
ck_tile::FillUniformDistribution<BQDataType>{1.f, 1.f}(bq_tensors[i]);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
}
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
b_k_n_tensors[i].get_element_space_size_in_bytes()));
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
c_m_n_tensors[i].get_element_space_size_in_bytes()));
aq_dev_buf.push_back(
std::make_unique<ck_tile::DeviceMem>(aq_tensors[i].get_element_space_size_in_bytes()));
bq_dev_buf.push_back(
std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
c_m_n_dev_buf[i]->SetZero();
c_m_n_tensors[i].SetZero();
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer();
const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back({p_a,
p_b,
p_c,
p_aq,
p_bq,
kbatch,
M,
N,
K,
AQK,
BQK,
stride_As[i],
stride_Bs[i],
stride_Cs[i],
stride_AQs[i],
stride_BQs[i]});
}
float ave_time = invoke_abquant_gemm<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(warmup, repeat, group_count, gemm_descs);
std::string op_name = "ABQuant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
std::size_t flop = 0, num_btype = 0;
for(int j = 0; j < group_count; ++j)
{
flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K;
num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K +
sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N +
sizeof(CDataType) * gemm_descs[j].M * gemm_descs[j].N;
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
for(int i = 0; i < group_count; i++)
{
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
}
bool pass{true};
if(validate)
{
for(int i = 0; i < group_count; ++i)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
// Reference implementation for ABQuantGrouped
ck_tile::reference_gemm_abquant<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
AQuantGroupSize,
BQuantGroupSize>(
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
Ks[i], kbatch, max_accumulated_value);
pass &=
ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results! in group [" + std::to_string(i) + "]",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "gemm[" << i
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
if(arg_parser.get_int("json") == 1)
{
dump_grouped_gemm_json_results<ALayout, BLayout, CLayout>(arg_parser.get_str("jsonfile"),
op_name,
group_count,
pass,
ave_time,
tflops,
gb_per_sec);
}
return pass;
}
template <typename PrecType, typename GemmConfig, typename BQuantGroupSize>
int run_abquant_grouped_gemm_example_prec_type_with_bquant(
std::string a_layout, std::string b_layout, std::string c_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
using AQDataType = typename Types::AccDataType;
using BQDataType = typename Types::AccDataType;
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped;
if(a_layout == "R" && b_layout == "C" && c_layout == "R")
{
return run_abquant_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R" && c_layout == "R")
{
return run_abquant_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
argc, argv, Row{}, Row{}, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R" && c_layout == "R")
{
return run_abquant_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
argc, argv, Col{}, Row{}, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
template <typename PrecType, typename GemmConfig>
int run_abquant_grouped_gemm_example_prec_type(std::string a_layout,
std::string b_layout,
std::string c_layout,
std::string bquant_group_size,
int argc,
char* argv[])
{
if(bquant_group_size == "1x1x128")
{
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_abquant_grouped_gemm_example_prec_type_with_bquant<PrecType,
GemmConfig,
BQuantGroupSize>(
a_layout, b_layout, c_layout, argc, argv);
}
else if(bquant_group_size == "1x128x128")
{
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return run_abquant_grouped_gemm_example_prec_type_with_bquant<PrecType,
GemmConfig,
BQuantGroupSize>(
a_layout, b_layout, c_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported BQuantGroupSize! Use 1x1x128 or 1x128x128.");
}
}
template <typename PrecType>
int run_abquant_gemm_example_persistency(std::string a_layout,
std::string b_layout,
std::string c_layout,
bool persistent,
std::string bquant_group_size,
int argc,
char* argv[])
{
if(persistent)
{
using GemmConfig = typename GemmQuantConfig<
ck_tile::QuantType::ABQuantGrouped>::template GemmConfig<PrecType, true>;
return run_abquant_grouped_gemm_example_prec_type<PrecType, GemmConfig>(
a_layout, b_layout, c_layout, bquant_group_size, argc, argv);
}
else
{
using GemmConfig = typename GemmQuantConfig<
ck_tile::QuantType::ABQuantGrouped>::template GemmConfig<PrecType, false>;
return run_abquant_grouped_gemm_example_prec_type<PrecType, GemmConfig>(
a_layout, b_layout, c_layout, bquant_group_size, argc, argv);
}
}
int run_abquant_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string c_layout = arg_parser.get_str("c_layout");
const std::string data_type = arg_parser.get_str("prec");
bool persistent = arg_parser.get_bool("persistent");
const std::string bquant_group_size = arg_parser.get_str("bquant_group_size");
if(data_type == "fp8")
{
return run_abquant_gemm_example_persistency<ck_tile::fp8_t>(
a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv);
}
else if(data_type == "bf8")
{
return run_abquant_gemm_example_persistency<ck_tile::bf8_t>(
a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type configuration.");
}
}

View File

@@ -79,8 +79,7 @@ float invoke_gemm(int n_warmup,
// earlier stage.
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
const bool splitk = args[0].k_batch > 1;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr},
@@ -109,7 +108,7 @@ float invoke_gemm(int n_warmup,
ADataType,
BDataType,
AccDataType,
CDataType>(stream, group_count, kargs_ptr, splitk);
CDataType>(stream, group_count, kargs_ptr);
}
return ave_time;

View File

@@ -95,8 +95,7 @@ float invoke_gemm(int n_warmup,
else
{
std::vector<ck_tile::GemmTransKernelArg<NumDTensor>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
const bool splitk = args[0].k_batch > 1;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, NumDTensor>{{arg.a_ptr},
@@ -119,18 +118,17 @@ float invoke_gemm(int n_warmup,
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<NumDTensor>),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time =
grouped_gemm_multi_d_tileloop<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
CDEElementWise>(stream, group_count, kargs_ptr, splitk);
ave_time = grouped_gemm_multi_d_tileloop<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
CDEElementWise>(stream, group_count, kargs_ptr);
}
return ave_time;
}

View File

@@ -170,13 +170,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
@@ -207,7 +204,6 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false,
1,
@@ -282,23 +278,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
return ave_time;
}

View File

@@ -113,13 +113,10 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config&
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
@@ -150,7 +147,6 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config&
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
@@ -216,23 +212,7 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config&
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
return ave_time;
}

View File

@@ -113,13 +113,10 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
using CodegenPipelineProblem =
std::conditional_t<MXFP4_Pipeline,
@@ -159,7 +156,6 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false,
1,
@@ -265,23 +261,7 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
return ave_time;
}

View File

@@ -89,13 +89,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
@@ -128,7 +125,6 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false, // FixedVectorSize
1, // VectorSizeC
@@ -201,23 +197,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
return ave_time;
}

View File

@@ -144,15 +144,11 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
@@ -184,7 +180,6 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false,
1,
@@ -261,37 +256,20 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
args.NumTokens * args.TopK * outputN * sizeof(CDataType),
s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
return ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time = ck_tile::launch_kernel(
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
float ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
return ave_time;
}

View File

@@ -61,8 +61,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation =
Splitk ? ck_tile::memory_operation_enum::atomic_add : ck_tile::memory_operation_enum::set;
ck_tile::ignore = Splitk;
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
@@ -98,7 +97,6 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
MXPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false, // FixedVectorSize
1, // VectorSizeC

View File

@@ -81,87 +81,45 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
// Epilogue selection: set to true for chainer-based, false for standard
// CShuffleEpilogue
constexpr bool UseChainerEpilogue = true;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using GemmEpilogue = std::conditional_t<
UseChainerEpilogue,
// Chainer-based epilogue
ck_tile::EpilogueChainer<ck_tile::CshuffleEpilogueSchedule<
ck_tile::CShuffleEpilogueChainProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>,
ck_tile::DefaultScheduleTag>>,
// Standard CShuffleEpilogue
ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>>;
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y
<< ", " << blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
};
if(args.k_batch == 1)
if(!Kernel::IsSupportedArgument(kargs))
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
else
if(s.log_level_ > 0)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
<< blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
#include "run_gemm_multi_d_fp16_example.inc"

View File

@@ -59,94 +59,80 @@ struct GroupedConvolutionBackwardDataInvoker
ConvConfig::NumWaveGroups>;
constexpr auto scheduler = ConvConfig::Scheduler;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
WeiDataType,
DsDataType,
AccDataType,
InDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
WeiDataType,
DsDataType,
AccDataType,
InDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
auto preprocess = [&]() {
ck_tile::hip_check_error(hipMemsetAsync(
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
};
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
auto preprocess = [&]() {
ck_tile::hip_check_error(hipMemsetAsync(
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
};
if(args.k_batch == 1)
{
return Run(MemoryOpSet{});
}
else
{
return Run(MemoryOpAtomicAdd{});
}
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
};

View File

@@ -59,104 +59,85 @@ struct GroupedConvolutionBackwardWeightInvoker
ConvConfig::NumWaveGroups>;
constexpr auto scheduler = ConvConfig::Scheduler;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
InDataType,
DsDataType,
AccDataType,
WeiDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
InDataType,
DsDataType,
AccDataType,
WeiDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
const auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
auto preprocess = [&]() {
if(args.k_batch > 1)
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
ck_tile::hip_check_error(hipMemsetAsync(
kargs.wei_ptr, 0, args.template GetWeightByte<WeiDataType>(), s.stream_id_));
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
auto preprocess = [&]() {
if(kargs.k_batch > 1)
{
ck_tile::hip_check_error(
hipMemsetAsync(kargs.wei_ptr,
0,
args.template GetWeightByte<WeiDataType>(),
s.stream_id_));
}
};
const auto ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
const auto split_k = kargs.k_batch;
return InvokerResult{ave_time, split_k};
};
if(args.k_batch == 1)
{
return Run(MemoryOpSet{});
}
else
{
return Run(MemoryOpAtomicAdd{});
}
float ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return InvokerResult{ave_time, args.k_batch};
}
};

View File

@@ -65,163 +65,143 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
constexpr auto scheduler = ConvConfig::Scheduler;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType, // A: Out
InDataType, // B: In
DsDataType,
AccDataType,
WorkspaceDataType, // C: Workspace normally Out
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType, // A: Out
InDataType, // B: In
DsDataType,
AccDataType,
WorkspaceDataType, // C: Workspace normally Out
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
const ck_tile::index_t spatial_lengths_accum =
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<ck_tile::index_t>());
ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum *
sizeof(WorkspaceDataType));
ck_tile::GroupedConvBwdWeightHostArgs ws_args = ck_tile::GroupedConvBwdWeightHostArgs(args);
auto c_ptr = ws_args.wei_ptr;
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
const ck_tile::index_t spatial_lengths_accum =
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<ck_tile::index_t>());
ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum *
sizeof(WorkspaceDataType));
ck_tile::GroupedConvBwdWeightHostArgs ws_args =
ck_tile::GroupedConvBwdWeightHostArgs(args);
auto c_ptr = ws_args.wei_ptr;
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
const auto kargs = Kernel::MakeKernelArgs(ws_args);
const auto kargs = Kernel::MakeKernelArgs(ws_args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
using BlockTile = ck_tile::sequence<2048>;
using BlockWarps = ck_tile::sequence<8>;
using WarpTile = ck_tile::sequence<64>;
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
using BlockTile = ck_tile::sequence<2048>;
using BlockWarps = ck_tile::sequence<8>;
using WarpTile = ck_tile::sequence<64>;
using ElementwiseShape =
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceDataType,
WorkspaceDataType,
WeiDataType,
ElementwiseShape,
XElementwiseOperation>;
using ElementwiseKernel =
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
using ElementwiseShape =
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceDataType,
WorkspaceDataType,
WeiDataType,
ElementwiseShape,
XElementwiseOperation>;
using ElementwiseKernel =
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
ck_tile::index_t total_elements = 1;
std::vector<ck_tile::index_t> shape = {
static_cast<ck_tile::index_t>(args.G_ * args.K_),
static_cast<ck_tile::index_t>(args.C_ * spatial_lengths_accum)};
ck_tile::index_t total_elements = 1;
std::vector<ck_tile::index_t> shape = {
static_cast<ck_tile::index_t>(args.G_ * args.K_),
static_cast<ck_tile::index_t>(args.C_ * spatial_lengths_accum)};
for(auto d : shape)
total_elements *= d;
for(auto d : shape)
total_elements *= d;
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
ck_tile::index_t kGridSize =
(total_elements + elements_per_block - 1) / elements_per_block;
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
auto input_size = ck_tile::make_tuple(shape[0], shape[1]);
auto input_tensors =
ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
auto input_size = ck_tile::make_tuple(shape[0], shape[1]);
// Check if the kernel configuration is supported
if(!ElementwiseKernel::IsSupportedArgument(input_size))
{
throw std::runtime_error(
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
}
// Check if the kernel configuration is supported
if(!ElementwiseKernel::IsSupportedArgument(input_size))
{
throw std::runtime_error(
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
auto preprocess = [&]() {
if(kargs.k_batch > 1)
ck_tile::hip_check_error(
hipMemsetAsync(ws_args.wei_ptr,
0,
shape[0] * shape[1] * sizeof(WorkspaceDataType),
s.stream_id_));
};
const auto ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
const auto split_k = kargs.k_batch;
return InvokerResult{ave_time, split_k};
auto preprocess = [&]() {
if(args.k_batch > 1)
ck_tile::hip_check_error(
hipMemsetAsync(ws_args.wei_ptr,
0,
shape[0] * shape[1] * sizeof(WorkspaceDataType),
s.stream_id_));
};
if(args.k_batch == 1)
{
return Run(MemoryOpSet{});
}
else
{
return Run(MemoryOpAtomicAdd{});
}
float ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
return InvokerResult{ave_time, kargs.k_batch};
}
};

View File

@@ -70,91 +70,74 @@ struct GroupedConvolutionForwardInvoker
// =====================================================================
// Regular Convolution: Simple, no split-image
// =====================================================================
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
};
// =====================================================================
// Split-K dispatch
// =====================================================================
if(args.k_batch == 1)
if(!Kernel::IsSupportedArgument(kargs))
{
return Run(MemoryOpSet{});
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
else
if(s.log_level_ > 0)
{
return Run(MemoryOpAtomicAdd{});
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
};

View File

@@ -213,8 +213,7 @@ struct GroupedConvolutionForwardInvoker
// =====================================================================
// Kernel launch lambda: Uses EnableSplitImage based on layout support
// =====================================================================
const auto Run = [&](const auto memory_operation_, const auto enable_split_image_) {
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto enable_split_image_) {
constexpr bool EnableSplitImage = enable_split_image_.value;
using GroupedConvTraitsType = std::conditional_t<EnableSplitImage,
@@ -255,7 +254,6 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
@@ -332,17 +330,11 @@ struct GroupedConvolutionForwardInvoker
// =====================================================================
if(use_split_image)
{
if(args.k_batch == 1)
return Run(MemoryOpSet{}, ck_tile::bool_constant<true>{});
else
return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant<true>{});
return Run(ck_tile::bool_constant<true>{});
}
else
{
if(args.k_batch == 1)
return Run(MemoryOpSet{}, ck_tile::bool_constant<false>{});
else
return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant<false>{});
return Run(ck_tile::bool_constant<false>{});
}
}
};

View File

@@ -13,11 +13,6 @@
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "conv_configs.hpp"
using MemoryOpSet =
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>;
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
const ck_tile::index_t kbatch,

View File

@@ -85,60 +85,44 @@ auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_conf
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AsDataType,
BsDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AsDataType,
BsDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y
<< ", " << blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
};
if(args.k_batch == 1)
if(!Kernel::IsSupportedArgument(kargs))
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
else
if(s.log_level_ > 0)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
<< blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
#include "run_gemm_multi_abd_fp16_example.inc"

View File

@@ -69,4 +69,64 @@ void abquant_quantgrouped_instance_factory(
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
"preshuffleb",
"non-preshufflequant",
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
"preshuffleb",
"non-preshufflequant",
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"abquant",
"preshuffleb",
"non-preshufflequant",
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"abquant",
"preshuffleb",
"non-preshufflequant",
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
}

View File

@@ -9,36 +9,194 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
void bquant_quantgrouped_preshufflequant_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"non-preshuffleb",
"preshufflequant",
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"non-preshuffleb",
"preshufflequant",
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"non-preshuffleb",
"preshufflequant",
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"bquant",
"non-preshuffleb",
"preshufflequant",
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"bquant",
"non-preshuffleb",
"preshufflequant",
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"bquant",
"non-preshuffleb",
"preshufflequant",
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
@@ -47,10 +205,63 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
lut[hash_multiple_strings(
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,

View File

@@ -74,9 +74,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>>;
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>;
const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -145,26 +146,33 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmConfig::Scheduler,
has_hot_loop_v,
tail_number_v>>>>;
using AQuantPipeline =
std::conditional_t<GemmConfig::PreshuffleQuant,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>;
using BQuantPipeline = std::conditional_t<
GemmConfig::PreshuffleB,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
std::conditional_t<
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
using ABQuantPipeline =
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
std::conditional_t<GemmConfig::PreshuffleQuant == true,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
std::conditional_t<
QuantMode == ck_tile::QuantType::ABQuantGrouped,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
std::conditional_t<
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>>>;
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
AQuantPipeline,
std::conditional_t<QuantMode == ck_tile::QuantType::ABQuantGrouped,
ABQuantPipeline,
BQuantPipeline>>>;
constexpr bool TiledPermuteN =
(BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
@@ -173,77 +181,30 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
printf(
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
}
// Epilogue selection: use chainer for RowCol/Tensor quant, standard for others
// Toggle to switch between chainer-based and standard CShuffleEpilogue
constexpr bool UseChainerEpilogue = true;
// Define the schedule tag based on quant mode
using ScheduleTag =
std::conditional_t<QuantMode == ck_tile::QuantType::RowColQuant,
ck_tile::RowColQuantScheduleTag,
std::conditional_t<QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::TensorQuantScheduleTag,
ck_tile::DefaultScheduleTag>>;
using GemmEpilogue = std::conditional_t<
UseChainerEpilogue && (QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant),
// Chainer-based epilogue for RowCol/Tensor quant modes
ck_tile::EpilogueChainer<ck_tile::CshuffleEpilogueSchedule<
ck_tile::CShuffleEpilogueChainProblem<
typename TypeConfig::ADataType,
std::conditional_t<
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
typename TypeConfig::ADataType,
typename TypeConfig::BDataType>,
ck_tile::tuple<>,
typename TypeConfig::AccDataType,
typename TypeConfig::CDataType,
ck_tile::tuple<>,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set,
1,
false,
1,
TiledPermuteN>,
ScheduleTag>>,
// Standard CShuffleEpilogue for other modes
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
typename TypeConfig::ADataType,
std::conditional_t<
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
typename TypeConfig::ADataType,
std::conditional_t<
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
typename TypeConfig::ADataType,
typename TypeConfig::BDataType>,
ck_tile::tuple<>,
typename TypeConfig::AccDataType,
typename TypeConfig::CDataType,
ck_tile::tuple<>,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set,
1,
false,
1,
TiledPermuteN>>>;
typename TypeConfig::BDataType>,
ck_tile::tuple<>,
typename TypeConfig::AccDataType,
typename TypeConfig::CDataType,
ck_tile::tuple<>,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transpose_c,
1,
false,
1,
TiledPermuteN>>;
using Kernel =
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
@@ -579,7 +540,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
@@ -955,8 +916,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if((QuantMode == ck_tile::QuantType::ABQuantGrouped ||
QuantMode == ck_tile::QuantType::AQuantGrouped ||
if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>) &&
GemmConfig::PreshuffleB)
@@ -985,7 +945,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
!GemmConfig::PreshuffleQuant)
!GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB)
{
if(a_layout == "R" && b_layout == "R")
{

View File

@@ -48,112 +48,87 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
GemmConfiguration::NUM_WAVE_GROUPS,
GemmConfiguration::PRESHUFFLE>;
const auto runKernel = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccumulatorDataType,
GemmShape,
GemmUniversalTraits,
GemmConfiguration::SCHEDULER>;
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccumulatorDataType,
GemmShape,
GemmUniversalTraits,
GemmConfiguration::SCHEDULER>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccumulatorDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfiguration::M_WARP,
GemmConfiguration::N_WARP,
GemmConfiguration::M_WARP_TILE,
GemmConfiguration::N_WARP_TILE,
GemmConfiguration::K_WARP_TILE,
UniversalGemmProblem::TransposeC,
memory_operation.value,
GemmConfiguration::NUM_WAVE_GROUPS>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccumulatorDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfiguration::M_WARP,
GemmConfiguration::N_WARP,
GemmConfiguration::M_WARP_TILE,
GemmConfiguration::N_WARP_TILE,
GemmConfiguration::K_WARP_TILE,
UniversalGemmProblem::TransposeC,
GemmConfiguration::NUM_WAVE_GROUPS>>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kernel_args = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
ck_tile::DeviceMem workspace_data(workspace_size);
auto kernel_args = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kernel_args))
{
// Clear the output C tensor results after each repetition of the kernel
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
}
if(stream_config.log_level_ > 0)
{
// Reset sk flags to zero before each repetition of the kernel
workspace_data.SetZero();
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
}
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kernel_args))
auto reset_data_buffers = [&]() {
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
// Clear the output C tensor results after each repetition of the kernel
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
}
if(stream_config.log_level_ > 0)
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
// Reset sk flags to zero before each repetition of the kernel
workspace_data.SetZero();
}
auto reset_data_buffers = [&]() {
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
// Clear the output C tensor results after each repetition of the kernel
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
}
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
{
// Reset sk flags to zero before each repetition of the kernel
workspace_data.SetZero();
}
};
std::function<void()> preprocess = reset_data_buffers;
float average_time =
ck_tile::launch_kernel_time_mask(stream_config,
preprocess,
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
Kernel{}, grids, blocks, 0, kernel_args));
ck_tile::index_t num_wgs_per_tile =
kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
return std::tuple{average_time, num_wgs_per_tile};
};
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy)
{
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// Since we are doing stream K, in the case of
// atomics, multiple workgroups may write to the
// same output tile in the C tensor, so we must
// atomic add the results (not set)
ck_tile::memory_operation_enum::atomic_add>{});
}
else // We are using ck_tile::StreamKReductionStrategy::Reduction
{
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// In this case, there is only ever 1 WG writing
// final results to each macro tile in the C
// tensor, so we can do a set.
ck_tile::memory_operation_enum::set>{});
}
std::function<void()> preprocess = reset_data_buffers;
float average_time =
ck_tile::launch_kernel_time_mask(stream_config,
preprocess,
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
Kernel{}, grids, blocks, 0, kernel_args));
ck_tile::index_t num_wgs_per_tile = kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
return std::tuple{average_time, num_wgs_per_tile};
}
#include "run_gemm_example.inc"

View File

@@ -92,67 +92,59 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataT
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
const auto Run = [&]() {
constexpr auto memory_operation =
ck_tile::memory_operation_enum::set; // Always set (no atomic_add)
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel =
ck_tile::BatchedContractionKernel<Problem, TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel =
ck_tile::BatchedContractionKernel<Problem, TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::GetBlockSize();
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::GetBlockSize();
if(!Kernel::IsSupportedArguments(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n");
}
if(!Kernel::IsSupportedArguments(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
auto kernel = ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs);
auto kernel = ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs);
return ck_tile::launch_kernel(s, kernel);
};
return Run();
return ck_tile::launch_kernel(s, kernel);
}
#define HANDLE_CASE(G, M, N, K) \

View File

@@ -51,13 +51,13 @@ struct ConvBwdWeightWmmaFactory
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.thread_cluster_order>,
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.thread_cluster_order>,
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>,
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>,
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The forward convolution kernel class instance.

View File

@@ -112,7 +112,7 @@ constexpr auto make_conv_instance()
return typename ReferenceFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
// CK Tile supports common factory for each direction
if constexpr(TileAlgorithm<AlgoType>)
else if constexpr(TileAlgorithm<AlgoType>)
{
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}

View File

@@ -116,7 +116,6 @@ struct ConvTileFactory
BLOCK_GEMM.warp_tile.k,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
// TODO:: This template parameter will be moved inside the kernel
ck_tile::memory_operation_enum::set,
BLOCK_GEMM.num_wave_groups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
SCALAR_PER_VECTOR.c>>;

View File

@@ -47,6 +47,11 @@ struct DataTypeToCK<DataType::FP8>
{
using type = ck::f8_t;
};
template <>
struct DataTypeToCK<DataType::U8>
{
using type = uint8_t;
};
struct CK_empty_tuple
{

View File

@@ -125,9 +125,9 @@ struct ReferenceFactory
// Direct Run method (simpler interface, direction-agnostic)
template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
static void Run(InPtrType input,
WeiPtrType weight,
OutPtrType output,
static void Run(InPtrType* input,
WeiPtrType* weight,
OutPtrType* output,
int G,
int N,
int K,
@@ -142,9 +142,9 @@ struct ReferenceFactory
if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
ck_tile::naive_grouped_conv_fwd<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
input,
weight,
output,
static_cast<const InDataType*>(input),
static_cast<const WeiDataType*>(weight),
static_cast<OutDataType*>(output),
G,
N,
K,
@@ -160,9 +160,9 @@ struct ReferenceFactory
{
ck_tile::
naive_grouped_conv_bwd_data<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
input,
weight,
output,
static_cast<InDataType*>(input),
static_cast<const WeiDataType*>(weight),
static_cast<const OutDataType*>(output),
G,
N,
K,
@@ -179,19 +179,20 @@ struct ReferenceFactory
ck_tile::naive_grouped_conv_bwd_weight<SPATIAL_DIM,
InDataType,
WeiDataType,
OutDataType>(input,
weight,
output,
G,
N,
K,
C,
input_spatial,
filter_spatial,
output_spatial,
strides,
dilations,
left_pads);
OutDataType>(
static_cast<const InDataType*>(input),
static_cast<WeiDataType*>(weight),
static_cast<const OutDataType*>(output),
G,
N,
K,
C,
input_spatial,
filter_spatial,
output_spatial,
strides,
dilations,
left_pads);
}
}

View File

@@ -7,11 +7,14 @@
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/testing/testing.hpp"
#include "ck_tile/builder/testing/extent.hpp"
#include "ck_tile/builder/testing/filter_extent.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/builder/testing/tensor_initialization.hpp"
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/testing/validation.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
/// This file implements common functionality for invoking/testing grouped
/// forward convolutions created through the CK Builder API. The main item
/// of it is the ConvArgs structure - which contains a complete description
@@ -37,12 +40,12 @@ namespace ck_tile::builder::test {
template <int SPATIAL_DIM>
struct ConvTensorLengths
{
size_t batch_size = 1; // N
size_t groups = 1; // G
size_t input_channels = 1; // C
size_t output_channels = 1; // K
Extent<SPATIAL_DIM> image = {}; // W, H, D
Extent<SPATIAL_DIM> filter = {}; // X, Y, Z
size_t batch_size = 1; // N
size_t groups = 1; // G
size_t input_channels = 1; // C
size_t output_channels = 1; // K
FilterExtent<SPATIAL_DIM> image = {}; // W, H, D
FilterExtent<SPATIAL_DIM> filter = {}; // X, Y, Z
};
/// @brief `Args` specialization for forward convolution.
@@ -59,6 +62,14 @@ struct Args<SIGNATURE>
constexpr static auto WEIGHT_TYPE = SIGNATURE.data_type;
constexpr static auto OUTPUT_TYPE = SIGNATURE.data_type;
constexpr static int INPUT_RANK = 3 + SPATIAL_DIM;
constexpr static int WEIGHT_RANK = 3 + SPATIAL_DIM;
constexpr static int OUTPUT_RANK = 3 + SPATIAL_DIM;
using InputDescriptor = TensorDescriptor<INPUT_TYPE, INPUT_RANK>;
using WeightDescriptor = TensorDescriptor<WEIGHT_TYPE, WEIGHT_RANK>;
using OutputDescriptor = TensorDescriptor<OUTPUT_TYPE, OUTPUT_RANK>;
// TODO: We shouldn't need to call into an internal namespace here.
using Ops = factory::internal::ElementwiseOps<SIGNATURE>;
@@ -72,10 +83,10 @@ struct Args<SIGNATURE>
// implementation (based on ConvParam in old CK / CK Tile) does not
// support strides at all.
Extent<SPATIAL_DIM> filter_strides;
Extent<SPATIAL_DIM> filter_dilation;
Extent<SPATIAL_DIM> input_left_pad;
Extent<SPATIAL_DIM> input_right_pad;
FilterExtent<SPATIAL_DIM> filter_strides;
FilterExtent<SPATIAL_DIM> filter_dilation;
FilterExtent<SPATIAL_DIM> input_left_pad;
FilterExtent<SPATIAL_DIM> input_right_pad;
Ops::AElementwiseOp a_elementwise_op;
Ops::BElementwiseOp b_elementwise_op;
@@ -84,7 +95,7 @@ struct Args<SIGNATURE>
/// This function returns the `TensorDescriptor` corresponding to
/// the input-tensor of the convolution problem. This can then
/// be used to, for example, allocate memory.
TensorDescriptor<INPUT_TYPE> make_input_descriptor() const
InputDescriptor make_input_descriptor() const
{
// TODO: We're using old CK functionality to compute the right
// values here, mainly because CK tile does not support the
@@ -95,31 +106,37 @@ struct Args<SIGNATURE>
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
typename Layouts::ALayout>(param);
return TensorDescriptor<INPUT_TYPE>(desc.GetLengths(), desc.GetStrides());
using Extent = typename InputDescriptor::Extent;
return InputDescriptor(Extent::from_vector(desc.GetLengths()),
Extent::from_vector(desc.GetStrides()));
}
/// This function returns the `TensorDescriptor` corresponding to
/// the weight-tensor of the convolution problem. This can then
/// be used to, for example, allocate memory.
TensorDescriptor<WEIGHT_TYPE> make_weight_descriptor() const
WeightDescriptor make_weight_descriptor() const
{
// See note in implementation of `make_input_descriptor`.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<
typename Layouts::BLayout>(param);
return TensorDescriptor<WEIGHT_TYPE>(desc.GetLengths(), desc.GetStrides());
using Extent = typename WeightDescriptor::Extent;
return WeightDescriptor(Extent::from_vector(desc.GetLengths()),
Extent::from_vector(desc.GetStrides()));
}
/// This function returns the `TensorDescriptor` corresponding to
/// the output-tensor of the convolution problem. This can then
/// be used to, for example, allocate memory.
TensorDescriptor<OUTPUT_TYPE> make_output_descriptor() const
OutputDescriptor make_output_descriptor() const
{
// See note in implementation of `make_input_descriptor`.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<
typename Layouts::ELayout>(param);
return TensorDescriptor<OUTPUT_TYPE>(desc.GetLengths(), desc.GetStrides());
using Extent = typename OutputDescriptor::Extent;
return OutputDescriptor(Extent::from_vector(desc.GetLengths()),
Extent::from_vector(desc.GetStrides()));
}
/// Convert the Args structure into a CK conv_param structure. This
@@ -244,12 +261,11 @@ UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args)
///
/// @see alloc_inputs()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
ValidUniqueInputs<SIGNATURE>
void init_inputs(const Args<SIGNATURE>& args, UniqueInputs<SIGNATURE>& inputs)
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
{
init_tensor_buffer_uniform_fp(inputs.input_buf, args.make_input_descriptor(), -2.0f, 2.0f);
init_tensor_buffer_uniform_fp(inputs.weight_buf, args.make_weight_descriptor(), -2.0f, 2.0f);
init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f);
init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f);
}
/// @brief `alloc_outputs()` specialization for forward convolution.
@@ -267,4 +283,19 @@ UniqueOutputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args)
};
}
/// @brief `validate()` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see validate()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
ValidationReport
validate(const Args<SIGNATURE>& args, Outputs<SIGNATURE> actual, Outputs<SIGNATURE> expected)
{
ValidationReport report;
report.check("output", args.make_output_descriptor(), actual.output, expected.output);
return report;
}
} // namespace ck_tile::builder::test

View File

@@ -3,10 +3,10 @@
#pragma once
#include <span>
#include <cstddef>
#include "ck_tile/builder/testing/conv_fwd.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include <type_traits>
#include <array>
/// This file contains the implementation details for invoking/testing
/// grouped convolution operations in old CK. The main item is the
@@ -15,6 +15,63 @@
namespace ck_tile::builder::test {
namespace detail {
/// @brief Concept for checking whether this is the reference convolution
/// implementation.
///
/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except
/// with some utility aliases. For that reason, its moved to this detail
/// namespace.
template <typename Conv,
auto SIGNATURE,
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
// TODO: We shouldn't need to call into an internal namespace here.
typename Ops = factory::internal::ElementwiseOps<SIGNATURE>>
concept CkConvInstance = requires(Conv& conv,
// TODO: This should be changed depending on IsMultiA etc.
// Currently that is not yet supported elsewhere anyway.
const void* p_a,
const void* p_b,
void* p_e,
std::array<index_t, SPATIAL_DIM + 3> lengths,
std::array<index_t, SPATIAL_DIM + 3> strides,
std::array<index_t, SPATIAL_DIM> filter,
Ops::AElementwiseOp elementwise_a,
Ops::BElementwiseOp elementwise_b,
Ops::CDEElementwiseOp elementwise_cde) {
{
conv.MakeArgument(p_a,
p_b,
// TODO: Support multiple D outputs.
{},
p_e,
// A lengths/strides
lengths,
strides,
// B lengths/strides
lengths,
strides,
// TODO: Ds lengths/strides
{},
{},
// E lengths/strides
lengths,
strides,
// strides/dilations/pads
filter,
filter,
filter,
filter,
// element-wise operations.
elementwise_a,
elementwise_b,
elementwise_cde)
};
};
} // namespace detail
/// @brief Concept for checking whether a convolution is invoked like old CK.
///
/// This concept is used to tell whether a convolution implementation is
@@ -24,13 +81,8 @@ namespace ck_tile::builder::test {
///
/// - SIGNATURE is the operation signature.
/// - Conv is a convolution instance created by the CK Builder API.
template <auto SIGNATURE, typename Conv>
concept IsCkConvInstance =
// TODO: This should be implemented by converting the signature into the
// type parameters for DeviceGroupedConvFwdMultipleABD. For now, just leave
// it empty. Improve when needed, you get the point. Also we should probably
// move this to the ck conv factory helper.
true;
template <typename Conv, auto SIGNATURE>
concept CkConvInstance = detail::CkConvInstance<Conv, SIGNATURE>;
/// @brief `run()` specialization for forward convolution and old CK.
///
@@ -39,10 +91,9 @@ concept IsCkConvInstance =
/// operation. This should be caught and reported by the testing framework.
///
/// @see run()
template <auto SIGNATURE, typename Conv>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
IsCkConvInstance<SIGNATURE, Conv>
void run(Conv& conv,
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
void run(CkConvInstance<SIGNATURE> auto& conv,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs)

View File

@@ -0,0 +1,114 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/testing/conv_fwd.hpp"
#include <stdexcept>
#include <vector>
/// This file contains the implementation details for invoking/testing
/// grouped convolution operations using the reference implementation.
/// The main item is the `run()` function, which is the primary way to
/// invoke the reference execution mechanism.
/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`,
/// but its made specific to the reference implementation, which is
/// invoked in a slightly different way.
namespace ck_tile::builder::test {
/// @brief Concept for checking whether this is the reference convolution
/// implementation.
///
/// This concept is used to tell whether a convolution implementation is
/// likely to be the reference implementation - that is, whether we should
/// invoke it like the reference kernel. This is mainly used with `run()` to
/// differentiate which implementation that should be invoked.
///
/// - SIGNATURE is the operation signature.
/// - Conv is a convolution instance created by the CK Builder API.
template <typename Conv, auto SIGNATURE>
concept RefConvInstance = requires(Conv& conv,
const void* input,
const void* weight,
void* output,
int G,
int N,
int K,
int C,
std::vector<long_index_t> dims) {
{
conv.Run(input,
weight,
output,
G,
N,
K,
C,
dims, // input_spatial
dims, // filter_spatial
dims, // output_spatial
dims, // strides
dims, // dilations
dims // left_pads
)
};
};
/// @brief `run()` specialization for forward convolution and the reference
/// implementation.
///
/// @tparam SIGNATURE Forward convolution signature.
/// @throws std::runtime_error if the arguments weren't actually valid for the
/// operation. This should be caught and reported by the testing framework.
///
/// @see run()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> &&
// TODO: Maybe we can unify this implementation for bwd/weight too?
// for now, just concern outselves with reference and see when the
// rest of the bwd/weight plumbing is there.
ConvDirectionIsForward<SIGNATURE>
void run(RefConvInstance<SIGNATURE> auto& conv,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs)
{
// We don't want to compute the output dims manually, just get
// them via the existing infrastructure
const auto param = args.to_ck_conv_param();
// TODO: The reference convolution is currently missing a few features.
// Just throw for now, but regard these as TODO items that should be resolved
// eventually.
// Right pads are not supported right now for some reason.
for(auto right_pad : param.input_right_pads_)
{
if(right_pad != 0)
throw std::runtime_error("TODO: Support right pad in reference conv");
}
if(!args.make_input_descriptor().is_packed())
throw std::runtime_error("TODO: Support non-packed input tensor in reference conv");
if(!args.make_weight_descriptor().is_packed())
throw std::runtime_error("TODO: Support non-packed weight tensor in reference conv");
if(!args.make_output_descriptor().is_packed())
throw std::runtime_error("TODO: Support non-packed output tensor in reference conv");
conv.Run(inputs.input,
inputs.weight,
outputs.output,
param.G_,
param.N_,
param.K_,
param.C_,
param.input_spatial_lengths_,
param.filter_spatial_lengths_,
param.output_spatial_lengths_,
param.conv_filter_strides_,
param.conv_filter_dilations_,
param.input_left_pads_);
}
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,150 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <hip/hip_runtime.h>
#include <source_location>
#include <stdexcept>
#include <sstream>
/// This file defines some utilities for dealing with HIP errors. In the CK-Builder
/// testing code, we'd like to just turn them into exceptions: This cleans up testing
/// code as we don't need to think about returning error codes, but its still much
/// cleaner than just creating a hard crash and thereby possibly interrupting other
/// units in the same test. The testing framework can catch these exceptions where
/// necessary.
///
/// While the exceptions defined in this file are in principle suitable for general
/// usage, HIP functions which return HIP error codes (`hipError_t`) should be
/// checked using the `check_hip` function.
namespace ck_tile::builder::test {
/// @brief Generic HIP exception.
///
/// This is a derivation of `std::runtime_error` which represents a HIP error code.
///
/// @see std::runtime_error
/// @see hipError_t
struct HipError : std::runtime_error
{
/// @brief Utility for formatting HIP error messages
///
/// Returns a human-readable description of a HIP error. Given a description of the
/// activity that the user tried to perform, this function appends the HIP-specific
/// information such as the stringified version of the error code, and the error
/// code itself (for reference).
///
/// @param user_msg User-given message about the activity at time of error.
/// @param code The status to report.
/// @param src The location where this error was discovered.
static std::string
format_error(std::string_view user_msg, hipError_t code, std::source_location src)
{
std::stringstream msg;
msg << user_msg << ": " << hipGetErrorString(code) << " (" << code << ")";
if(src.function_name())
msg << " in function '" << src.function_name();
msg << "' at " << src.file_name() << ":" << src.line() << ":" << src.column();
return msg.str();
}
/// @brief Construct a generic HIP error.
///
/// @param msg User-given message about the activity at time of error.
/// @param code The status to report.
/// @param src The location where this error was discovered. Defaults to the caller's
/// location.
HipError(std::string_view msg,
hipError_t code,
std::source_location src = std::source_location::current())
: std::runtime_error(format_error(msg, code, src)), code_(code)
{
}
/// @brief Retrieve the inner error code.
///
/// This function returns the status code that was encountered while checking an
/// operation for errors.
hipError_t code() const { return code_; }
private:
hipError_t code_;
};
/// @brief HIP out of memory error.
///
/// This a derivation of `HipError` which is specialized for Out-of-memory errors. This
/// makes it easier to attach additional context, and to match on these errors while
/// using `catch` blocks.
///
/// @see HipError
struct OutOfDeviceMemoryError : HipError
{
/// @brief Construct an out-of-device-memory error.
///
/// @param msg User-given message about the activity at time of error.
/// @param src The location where this error was discovered. Defaults to the caller's
/// location.
OutOfDeviceMemoryError(std::string_view msg = "failed to allocate device memory",
std::source_location src = std::source_location::current())
: HipError(msg, hipErrorOutOfMemory, src)
{
}
};
/// @brief Check HIP status for errors.
///
/// This function checks a HIP status code (obtained from a HIP function call) for any
/// errors. If the status `code` is not `hipSuccess`, this function throws an instance of
/// `HipError`. The exact type thats thrown depends on the status. If `code` represents
/// an out-of-memory error `hipErrorOutOfMemory`, then `OutOfDeviceMemoryError` will be
/// thrown instead.
///
/// @param msg User-given message about the activity at possible time of error.
/// @param code The HIP status code to examine.
/// @param src The location where this status was set. Defaults to the caller's location.
///
/// @throws HipError if `code` is not `hipSuccess`.
///
/// @see HipError
/// @see OutOfDeviceMemoryError
inline void check_hip(std::string_view msg,
hipError_t code,
std::source_location src = std::source_location::current())
{
// -Wswitch-enum throws a warning if this code is changed into a switch, even with
// the `default` label...
if(code == hipSuccess)
// When you beat the error allegations
return;
else if(code == hipErrorOutOfMemory)
throw OutOfDeviceMemoryError(msg, src);
else
throw HipError(msg, code, src);
}
/// @brief Check HIP status for errors.
///
/// This function is similar to `check_hip(std::string_view, hipError_t)`, except that a
/// default message is given.
///
/// @param code The HIP status code to examine.
/// @param src The location where this status was set. Defaults to the caller's location.
///
/// @throws HipError if `code` is not `hipSuccess`.
///
/// @see HipError
/// @see OutOfDeviceMemoryError
/// @see check_hip(std::string_view, hipError_t)
inline void check_hip(hipError_t code, std::source_location src = std::source_location::current())
{
check_hip(code == hipErrorOutOfMemory ? "failed to allocate device memory"
: "HIP runtime error",
code,
src);
}
} // namespace ck_tile::builder::test

View File

@@ -5,28 +5,29 @@
namespace ck_tile::builder::test {
/// This structure describes a 1-, 2-, or 3-D extent. Its used to
/// communicate 1-, 2- or 3-D sizes and strides of tensors.
/// Depending on the dimension, the structure will have the `width`,
/// `height`, and `depth` fields available.
/// This structure describes a 1-, 2-, or 3-D extent for convolution
/// filters. Its used to communicate 1-, 2- or 3-D sizes and strides
/// of tensors, specifically for convolution filters. Depending on the
/// dimension, the structure will have the `width`, `height`, and
/// `depth` fields available.
template <int SPATIAL_DIM>
struct Extent;
struct FilterExtent;
template <>
struct Extent<1>
struct FilterExtent<1>
{
size_t width = 1;
};
template <>
struct Extent<2>
struct FilterExtent<2>
{
size_t width = 1;
size_t height = 1;
};
template <>
struct Extent<3>
struct FilterExtent<3>
{
size_t width = 1;
size_t height = 1;

View File

@@ -3,19 +3,15 @@
#pragma once
#include "ck_tile/builder/testing/error.hpp"
#include <hip/hip_runtime.h>
#include <stdexcept>
#include <memory>
#include <numeric>
#include <span>
#include <concepts>
#include <hip/hip_runtime.h>
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/testing/type_traits.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <sstream>
/// This file deals with tensor memory allocation: Both the act of allocating
/// and (automatically) deallocating memory, as well as utilities for managing
/// the layout of tensor data in memory.
/// This file deals with tensor memory management and allocation. The main
/// item is the `DeviceBuffer`: An owned piece of device memory, which is
/// automatically freed when it goes out of scope.
namespace ck_tile::builder::test {
@@ -39,31 +35,6 @@ struct DeviceMemoryDeleter
}
};
/// @brief HIP out of memory error
///
/// This is a derivation of `std::runtime_error` specialized for HIP
/// out-of-memory errors.
///
/// @see std::runtime_error
struct OutOfDeviceMemoryError : std::runtime_error
{
/// @brief Utility for formatting out-of-memory error messages
///
/// Returns a human-readable description of a HIP out-of-memory error.
///
/// @param status The status to report
static std::string format_error(hipError_t status)
{
return std::string("failed to allocate hip memory: ") + hipGetErrorString(status) + " (" +
std::to_string(status) + ")";
}
/// @brief Construct an out-of-memory error using `status` as message.
///
/// @param status A HIP error status that was encountered while allocating memory.
OutOfDeviceMemoryError(hipError_t status) : std::runtime_error(format_error(status)) {}
};
/// @brief Automatically managed GPU memory.
///
/// The `DeviceBuffer` is an automatically managed pointer for GPU memory. When
@@ -96,117 +67,18 @@ inline DeviceBuffer alloc_buffer(size_t size)
std::byte* d_buf = nullptr;
if(const auto status = hipMalloc(&d_buf, size); status != hipSuccess)
{
throw OutOfDeviceMemoryError(status);
// Add some additional context
size_t free, total;
check_hip("failed to get HIP memory info", hipMemGetInfo(&free, &total));
std::stringstream ss;
ss << "failed to allocate device memory (tried to allocate " << size << " bytes with only "
<< free << " available)";
throw OutOfDeviceMemoryError(ss.str());
}
return DeviceBuffer(d_buf);
}
/// @brief Type managing tensor data layout in memory.
///
/// This structure describes a tensor in memory. It does not actually hold any
/// reference to memory, it just describes how the memory should be laid out if it
/// were.
///
/// @note This type is very much like ck_tile::HostTensorDescriptor, except that it
/// also includes the data type of the elements of htis tensor. This is mainly to
/// make the descriptor a _complete_ description of a tensor rather than just the
/// dimensions in strides, which helps in reducing clutter in uses of this type.
///
/// @note All strides are still in _elements_.
///
/// @tparam DT The conceptual data type of the tensor elements. This need not be the
/// type that the data is actually stored as in memory.
template <DataType DT>
struct TensorDescriptor
{
// For now, the implementation of this type is based on
// `ck_tile::HostTensorDescriptor`, so that we can prototype without
// reimplementing the `HostTensorDescriptor` for the 3rd time. You can regard
// the use of `ck_tile::HostTensorDescriptor` here as an implementation detail.
/// The conceptual data type of the tensor elements. This need not be the type
/// that the data is actually stored as in memory.
constexpr static DataType data_type = DT;
/// @brief Create a tensor descriptor from lengths and strides.
///
/// @param lengths A sequence of tensor lengths, the conceptial dimensions of
/// the tensor in elements.
/// @param strides A sequence of in-memory strides of the tensor, measured in
/// elements. Each element of `strides`` corresponds to one at the same index
/// in `lengths`, the amount of elements to skip in memory to find the next
/// element along that axis.
TensorDescriptor(std::span<const size_t> lengths, std::span<const size_t> strides)
: inner_descriptor_(lengths, strides)
{
// TODO: Validation of strides? For now we just delegate the details of the
// construction to the CK Tile HostTensorDescriptor.
}
/// Query the conceptual dimensions of the tensor.
///
/// @returns A span of tensor dimensions, one for every axis. Note that the order
/// does *not* correspond with memory layout, query the in-memory strides for
/// that.
///
/// @see get_strides()
std::span<const size_t> get_lengths() const { return inner_descriptor_.get_lengths(); }
/// Query the in-memory strides of the tensor.
///
/// @returns A span of tensor dimensions, one for every axis. Each element
/// corresponds directly with the stride in elements at the same index in the
/// tensor dimensions.
///
/// @see get_lengths()
std::span<const size_t> get_strides() const { return inner_descriptor_.get_strides(); }
/// @brief Compute total tensor size in elements.
///
/// This function returns the total size of the memory backing a tensor with
/// this descriptor in *elements*, including required extra size for strides.
///
/// @see get_element_space_size_in_bytes()
size_t get_element_space_size() const { return inner_descriptor_.get_element_space_size(); }
/// @brief Compute total tensor size in bytes.
///
/// This function is like `get_element_space_size()`, except that the returned
/// value is measured in *bytes* rather than *elements*. Use this function for
/// figuring out how much memory needs to be allocated for a particular tensor.
///
/// @see get_element_space_size()
size_t get_element_space_size_in_bytes() const
{
// For now, the backing type is the naive C++-type that represents the data
// type. When we are going to support packed types such as i4 and fp6, this
// is going to become more complicated.
return get_element_space_size() * data_type_sizeof(DT);
}
private:
ck_tile::HostTensorDescriptor inner_descriptor_;
};
/// @brief Allocate automatically managed GPU memory corresponding to a tensor descriptor.
///
/// This function is similar to `alloc_buffer()`, except that the required size is
/// derived automatically from a tensor descriptor. The returned buffer is valid for
/// tensors with that layout. Strides are also taken into account when computing the
/// required size.
///
/// @tparam DT The conceptual datatype of the elements of the tensor.
/// @param descriptor A descriptor of the memory layout of the tensor to allocate.
/// @throws OutOfDeviceMemoryError if memory allocation failed.
///
/// @see TensorDescriptor
/// @see DeviceBuffer
/// @see OutOfDeviceMemoryError
/// @see hipMalloc()
template <DataType DT>
DeviceBuffer alloc_tensor_buffer(const TensorDescriptor<DT>& descriptor)
{
return alloc_buffer(descriptor.get_element_space_size_in_bytes());
}
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,474 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <stdexcept>
#include <array>
#include <vector>
#include <sstream>
#include <concepts>
#include <algorithm>
#include <hip/hip_runtime.h>
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/testing/type_traits.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/host/host_tensor.hpp"
/// This file deals with tensor memory layout. The `TensorDescriptor` is the
/// main item, which is a type that describes (but not manages!) the layout
/// of tensor memory. There are also some related utilities.
namespace ck_tile::builder::test {
/// @brief Tensor dimensions type
///
/// An Extent describes size in tensor space, usually either the tensor lengths
/// (conceptual size) or the tensor strides (memory layout). This type is mainly
/// used by the `TensorDescriptor`. This type is based on `std::array<size_t, RANK>`
/// and supports all relevant operations on that.
///
/// @note In practical terms, this type is not just an alias of `std::array` for
/// two reasons: First, writing a separate type allows us to write a custom
/// CTAD deduction guideline. This allows users to write `Extent{1, 2, 3}` and
/// get an instance of the correct type, whereas `std::array{1, 2, 3}` yields an
/// instance of `std::array<int, 3>`. This, in turn, allows inferring the rank
/// from the instance (useful in combination with `make_descriptor`), as it alows
/// us to write `function(Extent{1, 2, 3})`. Note that `function({1, 2, 3})` is
/// not valid before C++26 because `{1, 2, 3}` is an initializer list (even if
/// `function` accepts an instance of `Extent`), which does not have a known size
/// at compile time. Second, creating a separate struct for the `Extent` allows
/// additional (static) member functions.
///
/// @tparam RANK The rank (number of spatial dimensions) of the tensor that this
/// extent describes a size of.
///
/// @see TensorDescriptor
/// @see make_descriptor
template <size_t RANK>
struct Extent : std::array<size_t, RANK>
{
using Base = std::array<size_t, RANK>;
// Note: Default constructor inherited from std::array.
/// @brief Construct an extent from an `std::vector`.
///
/// This function can be used to turn an `std::vector` into an `Extent`.
/// Because this code is mainly intended for testing, the vector's size is
/// checked. If its not equal to `RANK`, an exception is thrown.
///
/// @throws std::runtime_error if the size of `extent` is not equal to `RANK`.
static Extent from_vector(const std::vector<size_t>& extent)
{
if(extent.size() != RANK)
{
std::stringstream msg;
msg << "invalid rank! expected: " << RANK << ", got: " << extent.size();
throw std::runtime_error(msg.str());
}
Extent result;
std::copy_n(extent.begin(), RANK, result.begin());
return result;
}
// Note: std::array doesn't like generating indexing code when the RANK
// is zero. Looks like there is a missing __device__ overload in ROCm 7.1
// at least. Its not terribly important, but just override the default
// operator[] to fix it.
/// @brief Array indexing operator
///
/// `std::array` has issues with this operator when RANK=0, this version
/// fixes that.
///
/// @param i The index to index the array with.
///
/// @see std::array::operator[]
__device__ __host__ size_t operator[](size_t i) const
{
if constexpr(RANK > 0)
{
return Base::operator[](i);
}
else
{
__builtin_unreachable();
}
}
/// @brief Array indexing operator
///
/// `std::array` has issues with this operator when RANK=0, this version
/// fixes that.
///
/// @param i The index to index the array with.
///
/// @see std::array::operator[]
__device__ __host__ size_t& operator[](size_t i)
{
if constexpr(RANK > 0)
{
return Base::operator[](i);
}
else
{
__builtin_unreachable();
}
}
};
// This is a deduction guideline necessary to resolve `Extent{1, 2, 3}` to the
// correct type. This definition is practically the same as that of `std::array`.
template <typename... T>
Extent(T...) -> Extent<sizeof...(T)>;
/// @brief Concept for automatically deriving tensor memory layout.
///
/// A `TensorStridesGenerator` is a type which can be used to automatically
/// derive the strides (memory layout) of a tensor, given the tensor lengths.
/// This is mainly used to avoid manually computing strides.
///
/// Implementors of this concept are required to implement `operator()`,
/// which accepts an instance of `Extent<RANK>` (the tensor lengths) and
/// yields another instance of `Extent<RANK>` (the tensor strides). Note
/// that the returned strides are expected to be "pre-scanned", meaning
/// that the offset in memory of a tensor can be computed as
/// `dot(index * strides)` (where `*` is element-wise multiplication).
///
/// @see TensorDescriptor
/// @see PackedRightLayout
/// @see PackedLeftLayout
template <typename G, int RANK>
concept TensorStridesGenerator = requires(const G& generator, const Extent<RANK>& lengths) {
{ generator(lengths) } -> std::convertible_to<Extent<RANK>>;
};
/// @brief Layout generator where right-most dimension has stride 1 and
/// all dimensions are packed.
///
/// This structure implements a `TensorStridesGenerator` which generates
/// a memory layout which has the right-most dimension equal to 1, and
/// all other strides increase right-to-left as a products of the extent.
/// This corresponds with a row-major layout.
///
/// @see TensorStridesGenerator
/// @see TensorDescriptor
struct PackedRightLayout
{
/// @brief Stride generation implementation.
///
/// This is the main function which implements the stride generation
///
/// @tparam RANK The rank of the tensor.
///
/// @param lengths The lengths of the tensor.
///
/// @returns The tensor's memory layout according to the definition
/// of `PackedRightLayout`.
///
/// @see TensorStridesGenerator
template <size_t RANK>
Extent<RANK> operator()(const Extent<RANK>& lengths) const
{
Extent<RANK> strides = {};
size_t numel = 1;
for(size_t i = RANK; i > 0; --i)
{
strides[i - 1] = numel;
numel *= lengths[i - 1];
}
return strides;
}
};
static_assert(TensorStridesGenerator<PackedRightLayout, 3>,
"PackedRightLayout should be a TensorStridesGenerator!");
/// @brief Layout generator where left-most dimension has stride 1 and
/// all dimensions are packed.
///
/// This structure implements a `TensorStridesGenerator` which generates
/// a memory layout which has the left-most dimension equal to 1, and
/// all other strides increase left-to-right as a products of the extent.
/// This corresponds with a column-major layout.
///
/// @see TensorStridesGenerator
/// @see TensorDescriptor
struct PackedLeftLayout
{
/// @brief Stride generation implementation.
///
/// This is the main function which implements the stride generation
///
/// @tparam RANK The rank of the tensor.
///
/// @param lengths The lengths of the tensor.
///
/// @returns The tensor's memory layout according to the definition
/// of `PackedLeftLayout`.
///
/// @see TensorStridesGenerator
template <size_t RANK>
Extent<RANK> operator()(const Extent<RANK>& lengths) const
{
Extent<RANK> strides = {};
size_t numel = 1;
for(size_t i = 0; i < RANK; ++i)
{
strides[i] = numel;
numel *= lengths[i];
}
return strides;
}
};
static_assert(TensorStridesGenerator<PackedLeftLayout, 3>,
"PackedLeftLayout should be a TensorStridesGenerator!");
/// @brief Type managing tensor data layout in memory.
///
/// This structure describes a tensor in memory. It does not actually hold any
/// reference to memory, it just describes how the memory should be laid out if it
/// were.
///
/// @note This type is very much like ck_tile::HostTensorDescriptor, except that it
/// also includes the data type of the elements of htis tensor. This is mainly to
/// make the descriptor a _complete_ description of a tensor rather than just the
/// dimensions in strides, which helps in reducing clutter in uses of this type.
///
/// @note All strides are still in _elements_.
///
/// @tparam DT The conceptual data type of the tensor elements. This need not be the
/// type that the data is actually stored as in memory.
/// @tparam RANK The tensor "rank": the number of conceptial spatial dimensions that
/// the tensor covers.
template <DataType DT, size_t RANK>
struct TensorDescriptor
{
// For now, the implementation of this type is based on
// `ck_tile::HostTensorDescriptor`, so that we can prototype without
// reimplementing the `HostTensorDescriptor` for the 3rd time. You can regard
// the use of `ck_tile::HostTensorDescriptor` here as an implementation detail.
/// @brief Tensor extent alias
///
/// This alias represents a std::array which holds tensor dimensions. There is one
/// item for each dimension in the tensor, and each item corresponds with the
/// value for that dimension.
using Extent = ::ck_tile::builder::test::Extent<RANK>;
/// The conceptual data type of the tensor elements. This need not be the type
/// that the data is actually stored as in memory.
constexpr static DataType data_type = DT;
/// The tensor "rank": the number of conceptial spatial dimensions that the
/// tensor covers.
constexpr static size_t rank = RANK;
/// @brief Create a tensor descriptor from lengths and strides.
///
/// @param lengths A sequence of tensor lengths, the conceptial dimensions of
/// the tensor in elements.
/// @param strides A sequence of in-memory strides of the tensor, measured in
/// elements. Each element of `strides`` corresponds to one at the same index
/// in `lengths`, the amount of elements to skip in memory to find the next
/// element along that axis.
TensorDescriptor(const Extent& lengths, const Extent& strides)
: inner_descriptor_(lengths, strides)
{
// TODO: Validation of strides? For now we just delegate the details of the
// construction to the CK Tile HostTensorDescriptor.
}
/// @brief Create a tensor descriptor with lengths and automatic layout.
///
/// This function initializes a tensor descriptor using lengths, and by deriving
/// the memory layout from the layout generator `Generator`. The tensor will be
/// initialized with the strides yielded from `Generator`.
///
/// @tparam Generator The generator type to generate the strides with. For example,
/// `PackedRightLayout` or `PackedLeftLayout`.
///
/// @param lengths A sequence of tensor lengths, the conceptial dimensions of
/// the tensor in elements.
/// @param gen An instance of `Generator` to generate the strides with.
///
/// @see TensorStridesGenerator
/// @see PackedLeftLayout
/// @see PackedRightLayout
template <typename Generator>
requires TensorStridesGenerator<Generator, RANK>
TensorDescriptor(const Extent& lengths, const Generator& gen)
: TensorDescriptor(lengths, gen(lengths))
{
}
/// Query the conceptual dimensions of the tensor.
///
/// @returns A span of tensor dimensions, one for every axis. Note that the order
/// does *not* correspond with memory layout, query the in-memory strides for that.
///
/// @see get_strides()
Extent get_lengths() const
{
// TODO: This is ugly for now. We should ditch the HostTensorDescriptor, and
// after that this can just be `return lengths_;` (and make it const Extent&).
Extent result;
std::copy_n(inner_descriptor_.get_lengths().begin(), RANK, result.begin());
return result;
}
/// Query the in-memory strides of the tensor.
///
/// @returns A span of tensor dimensions, one for every axis. Each element
/// corresponds directly with the stride in elements at the same index in the
/// tensor dimensions.
///
/// @see get_lengths()
Extent get_strides() const
{
// TODO: This is ugly for now. We should ditch the HostTensorDescriptor, and
// after that this can just be `return strides_;` (and make it const Extent&).
Extent result;
std::copy_n(inner_descriptor_.get_strides().begin(), RANK, result.begin());
return result;
}
/// @brief Compute conceptual tensor size in elements.
///
/// This function returns the size of the tensor in elements. This function only
/// takes the lengths into account, not the strides. In order to allocate memory
/// for the tensor, use `get_element_space_size()`.
///
/// @see get_lengths
/// @see get_element_space_size
size_t get_element_size() const { return inner_descriptor_.get_element_size(); }
/// @brief Compute total tensor space size in elements.
///
/// This function returns the total size of the memory backing a tensor with
/// this descriptor in *elements*, including required extra size for strides.
///
/// @see get_element_space_size_in_bytes()
size_t get_element_space_size() const { return inner_descriptor_.get_element_space_size(); }
/// @brief Compute total tensor size in bytes.
///
/// This function is like `get_element_space_size()`, except that the returned
/// value is measured in *bytes* rather than *elements*. Use this function for
/// figuring out how much memory needs to be allocated for a particular tensor.
///
/// @see get_element_space_size()
size_t get_element_space_size_in_bytes() const
{
// For now, the backing type is the naive C++-type that represents the data
// type. When we are going to support packed types such as i4 and fp6, this
// is going to become more complicated.
return get_element_space_size() * data_type_sizeof(DT);
}
/// @brief Check if a tensor is packed in memory.
///
/// This function checks whether the tensor memory is "packed", that is, whether
/// all elements are continuous in memory with no gaps.
bool is_packed() const
{
// First sort by stride, then check if they match the scan of the
// sizes.
const auto& lengths = inner_descriptor_.get_lengths();
const auto& strides = inner_descriptor_.get_strides();
std::array<size_t, RANK> indices;
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&](auto i, auto j) {
return strides[i] < strides[j];
});
size_t x = 1;
for(size_t i = 0; i < RANK; ++i)
{
if(strides[indices[i]] != x)
return false;
x *= lengths[indices[i]];
}
return true;
}
/// @brief Get a tensor descriptor for the space backing a tensor.
///
/// This function returns a tensor descriptor which represents the buffer space
/// required to a tensor with this descriptor. This is mainly useful to process
/// buffers with functions which normally operate over tensor descriptors. The
/// resulting tensor descriptor describes a 1D tensor with the same number of
/// elements as in the space.
///
/// @see get_element_space_size()
TensorDescriptor<DT, 1> get_space_descriptor() const
{
ck_tile::builder::test::Extent<1> lengths = {this->get_element_space_size()};
ck_tile::builder::test::Extent<1> strides = {1};
return TensorDescriptor<DT, 1>(lengths, strides);
}
private:
ck_tile::HostTensorDescriptor inner_descriptor_;
};
/// @brief Tensor descriptor construction helper.
///
/// This function can be used to create a tensor descriptor. It accepts the same
/// parameters as the constructor of `TensorDescriptor`, that is, a sequence of
/// lengths and a sequence of strides (or a generator to generate the strides).
/// The main use of this function is that it allows automatic inference of the `RANK`
/// parameter. C++ constructors do not allow partial specification of type parameters,
/// and so its impossible to write `TensorDescriptor<DT> x(Extent{1, 2, 3}, ...)`
/// and have the `RANK` be automatically inferred. Functions do allow this though,
/// so this function can be used to write `make_descriptor(Extent{1, 2, 3}, ...)`
///
/// @tparam DT The conceptual data type of the tensor elements. This need not be the
/// type that the data is actually stored as in memory.
/// @tparam RANK The tensor "rank": the number of conceptial spatial dimensions that
/// the tensor covers.
///
/// @param lengths A sequence of tensor lengths, the conceptial dimensions of
/// the tensor in elements.
/// @param strides A sequence of in-memory strides of the tensor, or a generator
/// to generate those strides from the tensor lengths.
///
/// @see TensorDescriptor
template <DataType DT, size_t RANK>
TensorDescriptor<DT, RANK> make_descriptor(const Extent<RANK>& lengths, const auto& strides)
{
return TensorDescriptor<DT, RANK>(lengths, strides);
}
/// @brief Allocate automatically managed GPU memory corresponding to a tensor descriptor.
///
/// This function is similar to `alloc_buffer()`, except that the required size is
/// derived automatically from a tensor descriptor. The returned buffer is valid for
/// tensors with that layout. Strides are also taken into account when computing the
/// required size.
///
/// @tparam DT The conceptual datatype of the elements of the tensor.
/// @tparam RANK The conceptual rank (number of dimensions) of the tensor.
///
/// @param descriptor A descriptor of the memory layout of the tensor to allocate.
///
/// @throws OutOfDeviceMemoryError if memory allocation failed.
///
/// @see TensorDescriptor
/// @see DeviceBuffer
/// @see OutOfDeviceMemoryError
/// @see hipMalloc()
template <DataType DT, size_t RANK>
DeviceBuffer alloc_tensor_buffer(const TensorDescriptor<DT, RANK>& descriptor)
{
return alloc_buffer(descriptor.get_element_space_size_in_bytes());
}
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,258 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include <cstdint>
#include <concepts>
#include <array>
/// This file implements a generic GPU tensor "foreach" function. This
/// functionality turned out useful in separate parts of the testing
/// system, hence its implemented in a separate file. This version is
/// not particularly efficient (but it should at least be readable),
/// but it should be easy to replace the implementation in the future,
/// should that be needed.
namespace ck_tile::builder::test {
/// @brief Concept for constraining tensor iteration functors.
///
/// This concept checks that a functor has the correct signature for
/// use with the `tensor_foreach` function.
template <typename F, int RANK>
concept ForeachFunctor = requires(const F& f, const Extent<RANK>& index) {
{ f(index) } -> std::same_as<void>;
};
namespace detail {
/// @brief Default foreach kernel block size
///
/// This value is the default number of threads in each block when
/// executing the foreach kernel. This value is mostly arbitrary,
/// 256 is usually a good default for AMD GPUs.
///
/// @see tensor_foreach
constexpr int DEVICE_FOREACH_BLOCK_SIZE = 256;
/// @brief Tensor iteration kernel
///
/// This kernel implements the actual iteration logic, and is intended
/// to be used solely by `tensor_foreach` to iterate & invoke the
/// actual callback.
///
/// @tparam BLOCK_SIZE The number of threads in each block on the GPU.
/// @tparam RANK The rank (number of spatial dimensions) of the tensor to
/// iterate.
/// @tparam F The type of the callback to invoke. This function must be
/// compatible with execution as a __device__ function.
///
/// @param numel The total number of elements in the tensor.
/// @param shape_scan A right-exclusive scan of the shape of the tensor.
/// @param f The callback to invoke for each index of the tensor. This
/// functor must be eligible for running on the GPU.
template <int BLOCK_SIZE, size_t RANK, typename F>
requires ForeachFunctor<F, RANK>
__global__ __launch_bounds__(BLOCK_SIZE) //
void foreach_kernel(const size_t numel, Extent<RANK> shape_scan, F f)
{
const auto gid = blockIdx.x * BLOCK_SIZE + threadIdx.x;
for(size_t flat_idx = gid; flat_idx < numel; flat_idx += gridDim.x * BLOCK_SIZE)
{
// Compute the current index.
Extent<RANK> index = {};
size_t idx = flat_idx;
for(size_t i = 0; i < RANK; ++i)
{
const auto scanned_dim = shape_scan[i];
index[i] = idx / scanned_dim;
idx %= scanned_dim;
}
// Then invoke the callback with the index.
f(index);
}
}
/// @brief A utility to get a C++ type for a CKB type
///
/// Right now this is just an alias of an internal CKB helper,
/// but this should probably be moved elsewhere.
template <builder::DataType DT>
using cpp_type_t = typename builder::factory::internal::DataTypeToCK<DT>::type;
} // namespace detail
/// @brief Calculate tensor memory offset given index and strides.
///
/// This function returns the offset in memory in a tensor, given a particular
/// multi-dimensional index and a particular set of strides. Each value in the
/// index corresponds one-to-one with a value in the strides, which are the
/// index and stride at that dimension in the tensor. These strides must be
/// pre-scanned, meaning that each index is the absolute stride of elements
/// along that axis. In essence, this means that you should pass the output of
/// `TensorDescriptor::get_strides()` into this function.
///
/// @pre The index must be inside the tensor space.
///
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
///
/// @param index A multi-dimensional index inside the tensor space.
/// @param strides A set of strides, one for each dimension.
///
/// @see TensorDescriptor
template <size_t RANK>
__host__ __device__ size_t calculate_offset(const Extent<RANK>& index, const Extent<RANK>& strides)
{
size_t offset = 0;
#pragma unroll
for(size_t i = 0; i < RANK; ++i)
{
offset += index[i] * strides[i];
}
return offset;
}
/// @brief Invoke a callback on the GPU for every index in a tensor.
///
/// This function invokes a callback functor on the GPU, for each index in
/// a tensor. This function _only_ takes care of iterating over all indices
/// in a tensor of a particular shape; this function does not handle or know
/// about actual tensor data.
///
/// @note This function is currently implemented relatively naively: The
/// iteration order is always row-wise, implemented as a persistent kernel.
/// The main objective of this function is to be used with the CK-Builder
/// testing system, and so readability and correctness should be preferred
/// over performance. If this is ever a source of performance problems,
/// feel free to replace the implementation with something better.
///
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
///
/// @param shape The shape of the tensor to iterate over.
/// @param f The callback to invoke for each index of the tensor. This
/// functor must be eligible for running on the GPU.
///
/// @see ForeachFunctor
/// @see detail::foreach_kernel
template <size_t RANK>
void tensor_foreach(const Extent<RANK>& shape, ForeachFunctor<RANK> auto f)
{
constexpr int block_size = detail::DEVICE_FOREACH_BLOCK_SIZE;
const auto kernel = detail::foreach_kernel<block_size, RANK, decltype(f)>;
int occupancy;
check_hip(hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, block_size, 0));
int device;
check_hip(hipGetDevice(&device));
int multiprocessors;
check_hip(
hipDeviceGetAttribute(&multiprocessors, hipDeviceAttributeMultiprocessorCount, device));
// Pre-scan the shape to help indexing in the kernel.
// Note: the order is not that important, so long as the iteration
// order in the kernel is from large-to-small. Right layout is the
// easiest solution for that.
Extent<RANK> shape_scan;
size_t numel = 1;
for(int i = RANK; i > 0; --i)
{
shape_scan[i - 1] = numel;
numel *= shape[i - 1];
}
// Reset any errors from previous launches.
(void)hipGetLastError();
kernel<<<occupancy * multiprocessors, block_size>>>(numel, shape_scan, f);
check_hip(hipGetLastError());
}
/// @brief Concept for tensor initializing functors.
///
/// This concept checks that a functor has the correct signature for
/// use with the `fill_tensor` function.
template <typename F, builder::DataType DT, size_t RANK>
concept FillTensorFunctor = requires(const F& f, const Extent<RANK>& index) {
{ f(index) } -> std::convertible_to<detail::cpp_type_t<DT>>;
};
/// @brief Utility for initializing tensors.
///
/// This function is a utility helper for initializing tensors. It accepts a
/// tensor descriptor, buffer, and a callback. The callback is invoked for every
/// coordinate (which is passed to the callback), and the tensor is initialized
/// with resulting value.
///
/// @tparam DT The tensor element datatype
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
///
/// @param desc The descriptor of the tensor to initialize.
/// @param buffer The memory of the tensor to initialize.
/// @param f A functor used to get the value at a particular coordinate.
///
/// @see FillTensorFunctor
template <builder::DataType DT, size_t RANK>
void fill_tensor(const TensorDescriptor<DT, RANK>& desc,
void* buffer,
FillTensorFunctor<DT, RANK> auto f)
{
const auto strides = desc.get_strides();
tensor_foreach(desc.get_lengths(), [buffer, f, strides](const auto& index) {
using T = detail::cpp_type_t<DT>;
auto* ptr = static_cast<T*>(buffer);
const auto offset = calculate_offset(index, strides);
ptr[offset] = f(index);
});
}
/// @brief Concept for tensor buffer initializing functors.
///
/// This concept checks that a functor has the correct signature for
/// use with the `fill_tensor_buffer` function.
template <typename F, builder::DataType DT>
concept FillTensorBufferFunctor = requires(const F& f, size_t index) {
{ f(index) } -> std::convertible_to<detail::cpp_type_t<DT>>;
};
/// @brief Utility for initializing tensor buffers.
///
/// This function is a utility for initializing memory backing a tensor buffer. In
/// contrast to `fill_tensor`, this function first extracts the backing space of
/// the tensor, and then invokes the callback for each (flat) index. This function
/// is particular useful for initializing out-of-bounds indices with a known with a
/// known value.
///
/// @tparam DT The tensor element datatype
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
///
/// @param desc The descriptor of the tensor to initialize.
/// @param buffer The memory of the tensor to initialize.
/// @param f A functor used to get the value at a particular index.
///
/// @see FillTensorBufferFunctor
template <builder::DataType DT, size_t RANK>
void fill_tensor_buffer(const TensorDescriptor<DT, RANK>& desc,
void* buffer,
FillTensorBufferFunctor<DT> auto f)
{
fill_tensor(desc.get_space_descriptor(), buffer, [f](auto index) { return f(index[0]); });
}
template <builder::DataType DT, size_t RANK>
void clear_tensor_buffer(const TensorDescriptor<DT, RANK>& desc,
void* buffer,
detail::cpp_type_t<DT> value = detail::cpp_type_t<DT>{0})
{
fill_tensor_buffer(desc, buffer, [value]([[maybe_unused]] size_t i) { return value; });
}
} // namespace ck_tile::builder::test

View File

@@ -19,15 +19,30 @@
namespace ck_tile::builder::test {
template <DataType DT>
void init_tensor_buffer_uniform_int(const DeviceBuffer& buf,
const TensorDescriptor<DT>& descriptor,
int min_val,
int max_val)
/// @brief Initialize tensor data with a uniform int distribution
///
/// This function initializes a tensor's device memory with random integer data,
/// drawn from a uniform distribution. The initialization is done directly on the
/// GPU. Note that the entire buffer is filled with the specified distribution
/// regardless of whether the layout is packed.
///
/// @tparam DT The data type of the tensor memory to initialize
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
///
/// @param buf The device memory to initialize
/// @param descriptor A tensor descriptor describing the precise layout of the
/// tensor memory.
/// @param min_value The minimum value of the distribution (inclusive).
/// @param max_value The maximum value of the distribution (exclusive).
template <DataType DT, size_t RANK>
void init_tensor_buffer_uniform_int(void* buf,
const TensorDescriptor<DT, RANK>& descriptor,
int min_value,
int max_value)
{
size_t size = descriptor.get_element_space_size_in_bytes();
if(max_val - min_val <= 1)
if(max_value - min_value <= 1)
{
throw std::runtime_error("Error while filling device tensor with random integer data: max "
"value must be at least 2 greater than min value, otherwise "
@@ -38,19 +53,34 @@ void init_tensor_buffer_uniform_int(const DeviceBuffer& buf,
// we might be asked to generate int values on fp data types that don't have the required
// precision
if(static_cast<ck_type>(max_val - 1) == static_cast<ck_type>(min_val))
if(static_cast<ck_type>(max_value - 1) == static_cast<ck_type>(min_value))
{
throw std::runtime_error("Error while filling device tensor with random integer data: "
"insufficient precision in specified range");
}
size_t packed_size = ck::packed_size_v<ck_type>;
fill_tensor_uniform_rand_int_values<<<256, 256>>>(
static_cast<ck_type>(buf.get()), min_val, max_val, (size * packed_size) / sizeof(ck_type));
static_cast<ck_type>(buf), min_value, max_value, (size * packed_size) / sizeof(ck_type));
}
template <DataType DT>
void init_tensor_buffer_uniform_fp(const DeviceBuffer& buf,
const TensorDescriptor<DT>& descriptor,
/// @brief Initialize tensor data with a uniform float distribution
///
/// This function initializes a tensor's device memory with random floating data,
/// drawn from a uniform distribution. The initialization is done directly on the
/// GPU. Note that the entire buffer is filled with the specified distribution
/// regardless of whether the layout is packed.
///
/// @tparam DT The data type of the tensor memory to initialize
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
///
/// @param buf The device memory to initialize
/// @param descriptor A tensor descriptor describing the precise layout of the
/// tensor memory.
/// @param min_value The minimum value of the distribution (inclusive).
/// @param max_value The maximum value of the distribution (exclusive).
template <DataType DT, size_t RANK>
void init_tensor_buffer_uniform_fp(void* buf,
const TensorDescriptor<DT, RANK>& descriptor,
float min_value,
float max_value)
{
@@ -59,15 +89,30 @@ void init_tensor_buffer_uniform_fp(const DeviceBuffer& buf,
using ck_type = factory::internal::DataTypeToCK<DT>::type;
size_t packed_size = ck::packed_size_v<ck_type>;
fill_tensor_uniform_rand_fp_values<<<256, 256>>>(reinterpret_cast<ck_type*>(buf.get()),
fill_tensor_uniform_rand_fp_values<<<256, 256>>>(reinterpret_cast<ck_type*>(buf),
min_value,
max_value,
(size * packed_size) / sizeof(ck_type));
}
template <DataType DT>
void init_tensor_buffer_normal_fp(const DeviceBuffer& buf,
const TensorDescriptor<DT>& descriptor,
/// @brief Initialize tensor data with a normal float distribution
///
/// This function initializes a tensor's device memory with random floating data,
/// drawn from a normal distribution. The initialization is done directly on the
/// GPU. Note that the entire buffer is filled with the specified distribution
/// regardless of whether the layout is packed.
///
/// @tparam DT The data type of the tensor memory to initialize
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
///
/// @param buf The device memory to initialize
/// @param descriptor A tensor descriptor describing the precise layout of the
/// tensor memory.
/// @param sigma The standard deviation of the distribution.
/// @param mean The mean of the distribution.
template <DataType DT, size_t RANK>
void init_tensor_buffer_normal_fp(void* buf,
const TensorDescriptor<DT, RANK>& descriptor,
float sigma,
float mean)
{
@@ -76,7 +121,7 @@ void init_tensor_buffer_normal_fp(const DeviceBuffer& buf,
using ck_type = factory::internal::DataTypeToCK<DT>::type;
size_t packed_size = ck::packed_size_v<ck_type>;
fill_tensor_norm_rand_fp_values<<<256, 256>>>(
static_cast<ck_type*>(buf.get()), sigma, mean, (size * packed_size) / sizeof(ck_type));
static_cast<ck_type*>(buf), sigma, mean, (size * packed_size) / sizeof(ck_type));
}
} // namespace ck_tile::builder::test

View File

@@ -5,6 +5,8 @@
#include <concepts>
#include "ck_tile/builder/testing/validation.hpp"
/// This file is the main header for the CK-Builder testing system. A high-level
/// description of this testing system is documented in
/// `ck_tile/builder/testing/README.md`. This file deals mainly deals with the
@@ -78,7 +80,7 @@ namespace ck_tile::builder::test {
/// that this structure is an aggregrate so that it can be initialized using C++20
/// designated initializers to keep the tests readable.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
/// @tparam SIGNATURE The signature to specialize the structure for.
template <auto SIGNATURE>
struct Args;
@@ -98,7 +100,7 @@ struct Args;
/// structure is an aggregrate so that it can be initialized using C++20
/// designated initializers to keep the tests readable.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
/// @tparam SIGNATURE The signature to specialize the structure for.
template <auto SIGNATURE>
struct Inputs;
@@ -118,7 +120,7 @@ struct Inputs;
/// structure is an aggregrate so that it can be initialized using C++20
/// designated initializers to keep the tests readable.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
/// @tparam SIGNATURE The signature to specialize the structure for.
template <auto SIGNATURE>
struct Outputs;
@@ -133,7 +135,7 @@ struct Outputs;
/// @note The easiest way to implement this type is to use the `DeviceBuffer`
/// type to allocate individual device buffers for each input tensor.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
/// @tparam SIGNATURE The signature to specialize the structure for.
///
/// @see alloc_inputs()
/// @see ValidUniqueInputs
@@ -152,7 +154,7 @@ struct UniqueInputs;
/// @note The easiest way to implement this type is to use the `DeviceBuffer`
/// type to allocate individual device buffers for each output tensor.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
/// @tparam SIGNATURE The signature to specialize the structure for.
///
/// @see alloc_outputs()
/// @see ValidUniqueOutputs
@@ -195,7 +197,9 @@ concept ValidUniqueOutputs = requires(UniqueOutputs<SIGNATURE>& inputs) {
/// amount of memory required and then allocate it on the device, for example
/// using `alloc_buffer` or `alloc_tensor_buffer`.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
/// @tparam SIGNATURE The signature to specialize the structure for.
///
/// @param args The run-time arguments of the operation.
///
/// @see Inputs
/// @see UniqueInputs
@@ -208,16 +212,21 @@ UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args);
/// @brief Allocate inputs corresponding to a signature.
///
/// The `init_inputs()` function is used to initialize pseudo-random data
/// to the tensors specified in the Inputs structure.
/// to the tensors specified in the Inputs structure. Implementors should
/// fill each of the tensors in `inputs` with appropriate random data.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
///
/// @param args The run-time arguments of the operation.
/// @param inputs The operation inputs to initialize with random data.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see Inputs
/// @see UniqueInputs
/// @see tensor_initialization
template <auto SIGNATURE>
requires ValidUniqueInputs<SIGNATURE>
void init_inputs(const Args<SIGNATURE>& args, UniqueInputs<SIGNATURE>& inputs);
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs) = delete;
/// @brief Allocate outputs corresponding to a signature.
///
@@ -226,7 +235,12 @@ void init_inputs(const Args<SIGNATURE>& args, UniqueInputs<SIGNATURE>& inputs);
/// amount of memory required and then allocate it on the device, for example
/// using `alloc_buffer` or `alloc_tensor_buffer`.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
/// @tparam SIGNATURE The signature to specialize the structure for.
///
/// @param args The run-time arguments of the operation.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see Outputs
/// @see UniqueOutputs
@@ -234,7 +248,34 @@ void init_inputs(const Args<SIGNATURE>& args, UniqueInputs<SIGNATURE>& inputs);
/// @see alloc_tensor_buffer()
template <auto SIGNATURE>
requires ValidUniqueOutputs<SIGNATURE>
UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args);
UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args) = delete;
/// @brief Compare device operation outputs.
///
/// This function implements the main comparison functionality, used to compare
/// the output of one implementation for a particular `SIGNATURE` with that of
/// another. Usually, the `expected` output should be computed by a reference
/// implementation.
///
/// The implementation of this function generates a "report", which includes
/// detailed information about which tensors are different, how many elements
/// were incorrect, and where (a subset of) those elements are located within
/// the tensor. See `ValidationReport` for more information about the report.
///
/// @tparam SIGNATURE The signature to specialize the structure for.
///
/// @param args The run-time arguments of the operation.
/// @param actual The actual results, the results of the operation to-be-tested.
/// @param expected The expected results, the results of the reference implementation.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see ValidationReport
template <auto SIGNATURE>
ValidationReport validate(const Args<SIGNATURE>& args,
Outputs<SIGNATURE> actual,
Outputs<SIGNATURE> expected) = delete;
/// @brief Invoke a device operation created by CK Builder.
///
@@ -257,7 +298,7 @@ UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args);
/// @post The tensors in `outputs` are overwritten with the outputs of the device
/// operation.
///
/// @tparam SIGNATURE the signature to specialize this function for
/// @tparam SIGNATURE The signature to specialize this function for
/// @tparam Operation the kernel of the operation to invoke. This type should be
/// one that is created using the Builder API.
/// @param operation An instance of the operation to invoke.
@@ -265,10 +306,13 @@ UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args);
/// @param inputs The input tensor data. Will not be modified by this function.
/// @param outputs The output tensor data. The contents will be overwritten by
/// this function.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
template <auto SIGNATURE, typename Operation>
void run(Operation& operation,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs);
const Outputs<SIGNATURE>& outputs) = delete;
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,205 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/testing/error.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/builder/testing/tensor_foreach.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/type_convert.hpp"
#include <string_view>
#include <vector>
#include <algorithm>
#include <functional>
#include <bit>
/// This file implements functionality related to "validation", ie, functionality
/// to compare tensors. The functionality in this file should be testing-framework
/// agnostic, and it should NOT generate any error messages by itself. Instead,
/// all relevant information should be stored in the `ValidationReport` structure.
/// This structure should then be used to generate error messages, explainations,
/// etc, by the actual testing framework that the user has chosen.
namespace ck_tile::builder::test {
/// @brief Information about how a set of comparisons failed or succeeded.
///
/// This structure represents a "report" generated by comparing sets of tensors.
/// Its intended to be used as the result of `ckt::validate()`, where `check()`
/// is invoked for each of the output tensors of a particular device operation.
/// The test should be considered successful if _all_ of those checks passes,
/// which can inspected by asserting that `get_errors().size()` is 0.
struct ValidationReport
{
/// @brief Information related to a single tensor comparison.
///
/// This structure holds the information about the result of comparing
/// two particular tensors.
struct Case
{
/// The name of the tensor that was compared here, stored here for convenience
/// so that reporting any errors is easier.
std::string tensor_name;
/// The number of elements which were different between the two compared tensors.
uint64_t wrong_elements;
/// The total number of elements in each tensor.
uint64_t total_elements;
/// The number of elements which were bitwise 0.
uint64_t zero_elements;
/// @brief Check whether both the output and reference tensor were both all zeros.
///
/// If both tensors are all zero, it indicates either an incorrect testing setup
/// or an issue with the testing framework. For that reason we also consider that
/// a failure.
bool is_all_zero() const { return zero_elements == total_elements; }
/// @brief Return whether the check associated to this case was successful.
///
/// This function returns whether the check associated to this case was successful,
/// which is directly derived from checking whether the number of incorrect elements
/// was 0 AND whether the tensor was not all zero.
bool is_ok() const { return wrong_elements == 0 && !is_all_zero(); }
};
/// @brief Get comparison cases which were incorrect.
///
/// This function returns a vector of comparison cases that did not succeed, ie, for
/// which `Case::is_ok` return false. In order to check whether validation passed, it
/// is sufficient to assert that this function returns no cases.
std::vector<Case> get_errors() const
{
std::vector<Case> errors;
std::copy_if(reports_.begin(),
reports_.end(),
std::back_inserter(errors),
[](const auto& report) { return !report.is_ok(); });
return errors;
}
/// @brief Compare two tensors and record the results in the report.
///
/// This is the main function used to compare two tensors. The results of this
/// comparison, including any supplemental information, is recorded into the report.
///
/// @returns `false` if the comparison failed. If so, the details can be found via
/// `get_errors()`.
///
/// @tparam DT The data type of the tensors to check.
/// @tparam RANK The rank (number of spatial dimensions) of the tensor to check.
///
/// @param tensor_name The name of the tensors to check. This should be a value by which
/// whoever is debugging the associated test later can easily find out which of the
/// outputs of a device operation was incorrect.
/// @param descriptor The descriptor (memory layout) of the tensor.
/// @param actual The device buffer with the values of the tensor to-be-tested, ie, the
/// results of the device operation.
/// @param expected The device buffer with the values of the reference tensor. These are
/// treated as a "golden standard", and should usually be generated by a reference
/// implementation.
/// @param rtol The relative acceptable tolerance between two values.
/// @param atol The absolute acceptable tolerance between two values.
template <DataType DT, size_t RANK>
bool check(std::string_view tensor_name,
const TensorDescriptor<DT, RANK>& descriptor,
const void* actual,
const void* expected,
double rtol = 1e-3,
double atol = 1e-3);
private:
std::vector<Case> reports_;
};
template <DataType DT, size_t RANK>
bool ValidationReport::check(std::string_view tensor_name,
const TensorDescriptor<DT, RANK>& descriptor,
const void* actual_data,
const void* expected_data,
double rtol,
double atol)
{
const auto strides = descriptor.get_strides();
// During development and CI, only the kernels that were changed would fail, and so we can
// assume that the average case does not have errors. Therefore, split out testing into a
// quick test which just counts the incorrect elements, and a more in-depth test that also
// returns the indices of the incorrect items.
// Initial pass: count errors
// Allocate and reset counter
auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
const auto* actual = static_cast<const CKType*>(actual_data);
const auto* expected = static_cast<const CKType*>(expected_data);
static_assert(!std::is_same_v<CKType, double>,
"TODO implement compare_kernel() for double");
const auto offset = calculate_offset(index, strides);
const auto a = actual[offset];
const auto b = expected[offset];
const auto o = static_cast<double>(type_convert<float>(a));
const auto r = static_cast<double>(type_convert<float>(b));
const auto err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
// We expect the number of errors to be very low, so just use an atomic
// for now.
atomicAdd(d_error_count, 1);
}
// Now compare the numbers as bitwise too.
// Update the counter if they're both zero.
using Bytes = std::array<std::byte, sizeof(CKType)>;
bool all_zero = true;
for(auto x : std::bit_cast<Bytes>(a))
{
if(x != std::byte{0})
all_zero = false;
}
for(auto x : std::bit_cast<Bytes>(b))
{
if(x != std::byte{0})
all_zero = false;
}
if(all_zero)
{
atomicAdd(d_zero_count, 1);
}
});
uint64_t error_count = 0;
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
uint64_t zero_count = 0;
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
// TODO: Gather detailed coordinates.
reports_.push_back(Case{
.tensor_name = std::string(tensor_name),
.wrong_elements = error_count,
.total_elements = descriptor.get_element_size(),
.zero_elements = zero_count,
});
return reports_.back().is_ok();
}
} // namespace ck_tile::builder::test

View File

@@ -81,33 +81,36 @@ add_ck_builder_test(test_ckb_conv_builder
test_instance_traits_util.cpp
unit_device_buffer.cpp
unit_tensor_descriptor.cpp
unit_tensor_foreach.cpp
unit_error.cpp
unit_validation.cpp
unit_conv_elementwise_op.cpp
unit_conv_tensor_layout.cpp
unit_conv_tensor_type.cpp
unit_conv_thread_block.cpp
unit_conv_tuning_params.cpp)
# Tests the inline diff utility used for comparing strings in tests assertions
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
# GPU reference validation tests (in validation/ folder)
# 1. Reference kernel execution and InstanceTraits
add_ck_builder_test(test_ckb_reference_execution
validation/test_reference_execution.cpp
validation/test_reference_instance_traits.cpp)
target_link_libraries(test_ckb_reference_execution PRIVATE utility)
# Note: Optimized kernel validation tests will be added after merging dev branch
# with kernel Run() implementation from colleague's work
# Tests the inline diff utility used for comparing strings in tests assertions
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
# GPU reference validation tests (in validation/ folder)
# 1. Reference kernel execution and InstanceTraits
add_ck_builder_test(test_ckb_reference_execution
validation/test_reference_execution.cpp
validation/test_reference_instance_traits.cpp)
target_link_libraries(test_ckb_reference_execution PRIVATE utility)
# Note: Optimized kernel validation tests will be added after merging dev branch
# with kernel Run() implementation from colleague's work
# Tests convolution trait selection and configuration
add_ck_builder_test(test_ckb_conv_traits
conv/ck/test_conv_traits.cpp)
# Tests convolution problem description and parameter handling
add_ck_builder_test(test_ckb_conv_description
test_conv_description.cpp)
# Tests convolution trait selection and configuration
add_ck_builder_test(test_ckb_conv_traits
conv/ck/test_conv_traits.cpp)
# Tests convolution problem description and parameter handling
add_ck_builder_test(test_ckb_conv_description
test_conv_description.cpp)
################################################################################
# REGRESSION TESTS - Integration Tests (With Kernel Compilation)
################################################################################

View File

@@ -22,7 +22,7 @@ constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3,
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
.with_transfer(cku::BwdTransfer_4x64x1)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
.with_gemm_pipeline(ckb::PipelineVersion::V1);

View File

@@ -5,12 +5,16 @@
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/builder/testing/conv_fwd_ck.hpp"
#include "ck_tile/builder/testing/conv_fwd_reference.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "testing_utils.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using ck_tile::test::MatchesReference;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::FORWARD,
@@ -31,6 +35,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
TEST(Fwd2DFp16_CShufV3_GNHWC, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
@@ -78,11 +84,17 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
.cde_elementwise_op = {},
};
auto inputs = alloc_inputs(args);
auto outputs = alloc_outputs(args);
auto inputs = ckt::alloc_inputs(args);
auto outputs = ckt::alloc_outputs(args);
auto reference = ckt::alloc_outputs(args);
init_inputs(args, inputs);
ckt::init_inputs(args, inputs.get());
auto conv = Instance{};
ckt::run(conv, args, inputs.get(), outputs.get());
auto ref_conv = Reference{};
ckt::run(ref_conv, args, inputs.get(), reference.get());
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
}

View File

@@ -40,7 +40,6 @@ TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D
"Default",
"Intrawave",
"CShuffleEpilogue",
"set",
"pipeline_AgBgCrCompV3",
"DoubleSmemBuffer_0",
"NumWaveGroups_1",

View File

@@ -40,7 +40,6 @@ TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_
"Default",
"Intrawave",
"CShuffleEpilogue",
"set",
"pipeline_AgBgCrCompV3",
"DoubleSmemBuffer_0",
"NumWaveGroups_1",

View File

@@ -39,7 +39,6 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1
"Default",
"Intrawave",
"CShuffleEpilogue",
"set",
"pipeline_AgBgCrCompV3",
"DoubleSmemBuffer_0",
"NumWaveGroups_1",

View File

@@ -610,6 +610,32 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
ConvSpecializationBwdWeight_,
MultipleDSpecialization_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_,
TransposeParams_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_,
TransposeParams_,
GemmBatchOptions_,
TwoStageSpecialization_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<4>,
ConvSpecializationBwdWeight_,
GridGemm_,
Prefetch_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_,
WarpGemm_,

View File

@@ -81,7 +81,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
16 /*N_Warp_Tile*/,
16 /*K_Warp_Tile*/,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
ck_tile::memory_operation_enum::set /*memory_operation*/,
1 /*kNumWaveGroups*/,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;

View File

@@ -184,7 +184,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
16 /*N_Warp_Tile*/,
16 /*K_Warp_Tile*/,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
ck_tile::memory_operation_enum::set /*memory_operation*/,
1 /*kNumWaveGroups*/,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;

View File

@@ -161,8 +161,9 @@ struct DefaultAlgorithm
ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT;
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4,
.scheduler =
ckb::PipelineScheduler::INTRAWAVE};
};
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);

View File

@@ -795,7 +795,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
16 /*N_Warp_Tile*/,
16 /*K_Warp_Tile*/,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
ck_tile::memory_operation_enum::set /*memory_operation*/,
1 /*kNumWaveGroups*/,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;

View File

@@ -5,8 +5,7 @@
#include "testing_utils.hpp"
namespace ck_tile::builder {
namespace {
using ck_tile::test::inlineDiff;
TEST(InlineDiff, simpleColorDiff)
{
@@ -16,8 +15,8 @@ TEST(InlineDiff, simpleColorDiff)
// some easy tests
// you can veryfy the ungodly strings are meaningful by running echo -e "<string>"
EXPECT_THAT(test::inlineDiff(str1, str2, true), "hello");
EXPECT_THAT(test::inlineDiff(str1, str3, true),
EXPECT_THAT(inlineDiff(str1, str2, true), "hello");
EXPECT_THAT(inlineDiff(str1, str3, true),
"[\x1B[36mwor\x1B[0m|\x1B[35mhel\x1B[0m]l[\x1B[36md\x1B[0m|\x1B[35mo\x1B[0m]");
}
@@ -28,8 +27,8 @@ TEST(InlineDiff, noColorDiff)
std::string str3{"world"};
// some easy tests without color
EXPECT_THAT(test::inlineDiff(str1, str2, false), "hello");
EXPECT_THAT(test::inlineDiff(str1, str3, false), "[wor|hel]l[d|o]");
EXPECT_THAT(inlineDiff(str1, str2, false), "hello");
EXPECT_THAT(inlineDiff(str1, str3, false), "[wor|hel]l[d|o]");
}
TEST(InlineDiff, complexColorDiff)
@@ -42,11 +41,8 @@ TEST(InlineDiff, complexColorDiff)
"this part has degeahc, this part has, this part added, this part has ana extra letter"};
EXPECT_THAT(
test::inlineDiff(str5, str4, true),
inlineDiff(str5, str4, true),
"this part has [\x1B[36mchanged\x1B[0m|\x1B[35mdegeahc\x1B[0m], this part has[\x1B[36m "
"been left out\x1B[0m|\x1B[35m\x1B[0m], this part[\x1B[36m\x1B[0m|\x1B[35m added\x1B[0m], "
"this part has an[\x1B[36m\x1B[0m|\x1B[35ma\x1B[0m] extra letter");
};
} // namespace
} // namespace ck_tile::builder

View File

@@ -2,6 +2,7 @@
// SPDX-License-Identifier: MIT
#include <ck/library/tensor_operation_instance/device_operation_instance_factory.hpp>
#include "ck_tile/builder/testing/testing.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <string>
@@ -21,6 +22,16 @@
/// dedicated function to override to provide printing support.
std::ostream& operator<<(std::ostream& os, hipError_t status);
namespace ck_tile::builder::test {
template <auto SIGNATURE>
std::ostream& operator<<(std::ostream& os, [[maybe_unused]] Outputs<SIGNATURE> outputs)
{
return os << "<tensor outputs>";
}
} // namespace ck_tile::builder::test
namespace ck_tile::test {
static bool isTerminalOutput() { return isatty(fileno(stdout)) || isatty(fileno(stderr)); }
@@ -150,4 +161,47 @@ struct HipStatusMatcher : public ::testing::MatcherInterface<hipError_t>
/// @param error The error to expect.
::testing::Matcher<hipError_t> HipError(hipError_t error);
template <auto SIGNATURE>
struct ReferenceOutputMatcher
: public ::testing::MatcherInterface<builder::test::Outputs<SIGNATURE>>
{
ReferenceOutputMatcher(const builder::test::Args<SIGNATURE>& args,
builder::test::Outputs<SIGNATURE> expected)
: args_(&args), expected_(expected)
{
}
bool MatchAndExplain(builder::test::Outputs<SIGNATURE> actual,
[[maybe_unused]] ::testing::MatchResultListener* listener) const override
{
const auto report = ck_tile::builder::test::validate(*args_, actual, expected_);
const auto errors = report.get_errors();
if(listener->IsInterested() && !errors.empty())
{
*listener << errors.size() << " tensors failed to validate";
}
return errors.empty();
}
void DescribeTo(std::ostream* os) const override { *os << "<tensor outputs>"; }
void DescribeNegationTo(std::ostream* os) const override
{
*os << "isn't equal to <tensor outputs>";
}
const builder::test::Args<SIGNATURE>* args_;
builder::test::Outputs<SIGNATURE> expected_;
};
template <auto SIGNATURE>
::testing::Matcher<builder::test::Outputs<SIGNATURE>>
MatchesReference(const builder::test::Args<SIGNATURE>& args,
builder::test::Outputs<SIGNATURE> expected)
{
return ::testing::MakeMatcher(new ReferenceOutputMatcher<SIGNATURE>(args, expected));
}
} // namespace ck_tile::test

View File

@@ -11,40 +11,27 @@ namespace {
namespace ckb = ck_tile::builder;
using ck_tile::builder::factory::internal::DataTypeToCK;
TEST(ConvTensorType, AssignsTypesForFP16)
{
using CKType = DataTypeToCK<ckb::DataType::FP16>::type;
EXPECT_TRUE((std::is_same_v<CKType, ck::half_t>));
}
template <ckb::DataType DT, typename T>
constexpr auto check_same = std::is_same_v<typename DataTypeToCK<DT>::type, T>;
TEST(ConvTensorType, AssignsTypesForBF16)
TEST(ConvTensorType, Exhaustive)
{
using CKType = DataTypeToCK<ckb::DataType::BF16>::type;
EXPECT_TRUE((std::is_same_v<CKType, ck::bhalf_t>));
}
using enum ckb::DataType;
TEST(ConvTensorType, AssignsTypesForFP32)
{
using CKType = DataTypeToCK<ckb::DataType::FP32>::type;
EXPECT_TRUE((std::is_same_v<CKType, float>));
}
TEST(ConvTensorType, AssignsTypesForINT32)
{
using CKType = DataTypeToCK<ckb::DataType::INT32>::type;
EXPECT_TRUE((std::is_same_v<CKType, int32_t>));
}
TEST(ConvTensorType, AssignsTypesForI8)
{
using CKType = DataTypeToCK<ckb::DataType::I8>::type;
EXPECT_TRUE((std::is_same_v<CKType, int8_t>));
}
TEST(ConvTensorType, AssignsTypesForFP8)
{
using CKType = DataTypeToCK<ckb::DataType::FP8>::type;
EXPECT_TRUE((std::is_same_v<CKType, ck::f8_t>));
const auto type = FP32;
// This switch ensures that we get a warning (error with -Werror) if
// a variant is missing.
switch(type)
{
case UNDEFINED_DATA_TYPE: break;
case FP32: EXPECT_TRUE((check_same<FP32, float>)); break;
case FP16: EXPECT_TRUE((check_same<FP16, ck::half_t>)); break;
case BF16: EXPECT_TRUE((check_same<BF16, ck::bhalf_t>)); break;
case INT32: EXPECT_TRUE((check_same<INT32, uint32_t>)); break;
case FP8: EXPECT_TRUE((check_same<FP8, ck::f8_t>)); break;
case I8: EXPECT_TRUE((check_same<I8, int8_t>)); break;
case U8: EXPECT_TRUE((check_same<U8, uint8_t>)); break;
}
}
} // namespace

View File

@@ -19,7 +19,7 @@ TEST(ConvTuningParams, AssignsBlockGemmParams)
{
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3;
ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE;
} block_gemm;
} block_gemm_pipeline;
} kAlgorithm;
constexpr auto block_gemm = SetBlockGemm<kAlgorithm>();
@@ -42,10 +42,7 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion)
{
constexpr struct Algorithm
{
struct GridwiseGemm
{
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
} gridwise_gemm;
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
} kAlgorithm;
constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion<kAlgorithm>();

View File

@@ -2,10 +2,11 @@
// SPDX-License-Identifier: MIT
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "testing_utils.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>
#include <array>
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
@@ -54,6 +55,11 @@ TEST(DeviceBuffer, AutoFree)
// Trying to use a pointer after freeing should return en error in HIP.
EXPECT_THAT(hipMemset(ptr, 0xFF, size), HipError(hipErrorInvalidValue));
// Reset internal HIP error state.
// Otherwise, the error may leak into other tests, triggering anything that
// checks the output of hipGetLastError();
(void)hipGetLastError();
}
TEST(DeviceBuffer, ThrowsOnOom)
@@ -62,13 +68,16 @@ TEST(DeviceBuffer, ThrowsOnOom)
auto check = [] { auto buffer = ckt::alloc_buffer(size); };
EXPECT_THAT(check, Throws<ckt::OutOfDeviceMemoryError>());
// Reset internal HIP error state.
// Otherwise, the error may leak into other tests, triggering anything that
// checks the output of hipGetLastError();
(void)hipGetLastError();
}
TEST(DeviceBuffer, AllocTensorBuffer)
{
std::vector<size_t> lengths = {128, 128, 128};
std::vector<size_t> strides = {128 * 128, 128, 1};
ckt::TensorDescriptor<ckb::DataType::FP32> descriptor(lengths, strides);
ckt::TensorDescriptor<ckb::DataType::FP32, 3> descriptor({128, 128, 128}, {128 * 128, 128, 1});
auto buffer = ckt::alloc_tensor_buffer(descriptor);

View File

@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/builder/testing/error.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "testing_utils.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
namespace ckt = ck_tile::builder::test;
using ::testing::AllOf;
using ::testing::HasSubstr;
using ::testing::Throws;
using ::testing::ThrowsMessage;
[[noreturn]] void throw_error() { throw ckt::HipError("test error", hipErrorInvalidValue); }
TEST(HipError, SourceInfo)
{
EXPECT_THAT(throw_error,
ThrowsMessage<ckt::HipError>(AllOf(
// The error message should include...
// ...the user message
HasSubstr("test error"),
// ...the HIP message
HasSubstr("invalid argument"),
// ...the HIP status code,
HasSubstr("(1)"),
// ...the filename
HasSubstr("experimental/builder/test/unit_error.cpp"),
// ...the function name
HasSubstr("throw_error")
// Note: Don't include the row/column so that we can move
// stuff around in this file.
)));
}
TEST(CheckHip, BasicUsage)
{
EXPECT_THAT([] { ckt::check_hip(hipSuccess); }, Not(Throws<ckt::HipError>()));
EXPECT_THAT([] { ckt::check_hip(hipErrorNotMapped); }, Throws<ckt::HipError>());
EXPECT_THAT([] { ckt::check_hip(hipErrorOutOfMemory); }, Throws<ckt::OutOfDeviceMemoryError>());
EXPECT_THAT([] { ckt::check_hip("test message", hipErrorAlreadyMapped); },
ThrowsMessage<ckt::HipError>(HasSubstr("test message")));
}

View File

@@ -1,25 +1,28 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "testing_utils.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <array>
#include <vector>
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
using ::testing::ElementsAreArray;
using ::testing::Ge;
using ::testing::Eq;
using ::testing::Throws;
TEST(TensorDescriptor, Basic)
{
constexpr auto dt = ckb::DataType::FP16;
std::vector<size_t> lengths = {123, 456, 789};
std::vector<size_t> strides = {456 * 789, 789, 1};
constexpr auto dt = ckb::DataType::FP16;
constexpr size_t rank = 3;
ckt::Extent lengths = {123, 456, 789};
ckt::Extent strides = {456 * 789, 789, 1};
ckt::TensorDescriptor<dt> descriptor(lengths, strides);
ckt::TensorDescriptor<dt, rank> descriptor(lengths, strides);
EXPECT_THAT(descriptor.get_lengths(), ElementsAreArray(lengths));
EXPECT_THAT(descriptor.get_strides(), ElementsAreArray(strides));
@@ -27,21 +30,162 @@ TEST(TensorDescriptor, Basic)
TEST(TensorDescriptor, ComputeSize)
{
constexpr auto dt = ckb::DataType::FP32;
std::vector<size_t> lengths = {305, 130, 924};
std::vector<size_t> strides = {1000 * 1000, 1, 1000};
constexpr auto dt = ckb::DataType::FP32;
constexpr size_t rank = 3;
ckt::Extent lengths = {305, 130, 924};
ckt::Extent strides = {1001 * 1000, 1, 1000};
ckt::TensorDescriptor<dt> descriptor(lengths, strides);
ckt::TensorDescriptor<dt, rank> descriptor(lengths, strides);
// Compute the location of the last item in memory, then add one
// to get the minimum size.
size_t expected_size = 1;
// Compute the location of the last item in memory,
// then add one to get the minimum size.
size_t expected_size = 1;
size_t expected_numel = 1;
for(size_t i = 0; i < lengths.size(); ++i)
{
expected_size += (lengths[i] - 1) * strides[i];
expected_numel *= lengths[i];
}
EXPECT_THAT(descriptor.get_element_space_size(), Ge(expected_size));
EXPECT_THAT(descriptor.get_element_size(), Eq(expected_numel));
EXPECT_THAT(descriptor.get_element_space_size(), Eq(expected_size));
EXPECT_THAT(descriptor.get_element_space_size_in_bytes(),
Ge(expected_size * ckt::data_type_sizeof(dt)));
Eq(expected_size * ckt::data_type_sizeof(dt)));
}
TEST(TensorDescriptor, PackedRightLayout)
{
const ckt::Extent lengths = {5125, 623, 1177, 1534};
const auto strides = ckt::PackedRightLayout{}(lengths);
EXPECT_THAT(strides, ElementsAreArray({623 * 1177 * 1534, 1177 * 1534, 1534, 1}));
}
TEST(TensorDescriptor, PackedLeftLayout)
{
const ckt::Extent lengths = {4, 15, 925, 662, 1462};
const auto strides = ckt::PackedLeftLayout{}(lengths);
EXPECT_THAT(strides, ElementsAreArray({1, 4, 4 * 15, 4 * 15 * 925, 4 * 15 * 925 * 662}));
}
TEST(TensorDescriptor, MakeDescriptor)
{
{
const ckt::Extent lengths = {10, 11, 12, 13, 14};
// Note: automatic inference of RANK.
const auto desc =
ckt::make_descriptor<ckb::DataType::INT32>(lengths, ckt::PackedRightLayout{});
EXPECT_THAT(desc.get_lengths(), ElementsAreArray(lengths));
EXPECT_THAT(desc.get_strides(),
ElementsAreArray({11 * 12 * 13 * 14, 12 * 13 * 14, 13 * 14, 14, 1}));
}
{
const ckt::Extent lengths = {4, 3, 2};
const ckt::Extent strides = {60, 1, 7};
// Note: automatic inference of RANK.
const auto desc = ckt::make_descriptor<ckb::DataType::FP8>(lengths, strides);
EXPECT_THAT(desc.get_lengths(), ElementsAreArray(lengths));
EXPECT_THAT(desc.get_strides(), ElementsAreArray(strides));
}
}
TEST(TensorDescriptor, GetSpaceDescriptor)
{
{
const auto desc = ckt::make_descriptor<ckb::DataType::FP32>(ckt::Extent{4, 4, 4},
ckt::PackedLeftLayout{});
const auto space = desc.get_space_descriptor();
const auto expected = 4 * 4 * 4;
EXPECT_THAT(decltype(space)::data_type, Eq(ckb::DataType::FP32));
EXPECT_THAT(decltype(space)::rank, Eq(1));
EXPECT_THAT(decltype(space)::data_type, Eq(ckb::DataType::FP32));
EXPECT_THAT(decltype(space)::rank, Eq(1));
EXPECT_THAT(space.get_lengths(), ElementsAreArray({expected}));
EXPECT_THAT(space.get_strides(), ElementsAreArray({1}));
EXPECT_THAT(space.get_element_size(), Eq(expected));
EXPECT_THAT(space.get_element_space_size(), Eq(expected));
}
{
const ckt::Extent lengths = {6, 3, 4};
const ckt::Extent strides = {102, 1, 2002};
const auto desc = ckt::make_descriptor<ckb::DataType::FP32>(lengths, strides);
const auto space = desc.get_space_descriptor();
// Compute the location of the last item in memory,
// then add one to get the minimum size.
size_t expected_size = 1;
for(size_t i = 0; i < lengths.size(); ++i)
{
expected_size += (lengths[i] - 1) * strides[i];
}
EXPECT_THAT(decltype(space)::data_type, Eq(ckb::DataType::FP32));
EXPECT_THAT(decltype(space)::rank, Eq(1));
EXPECT_THAT(space.get_lengths(), ElementsAreArray({expected_size}));
EXPECT_THAT(space.get_strides(), ElementsAreArray({1}));
EXPECT_THAT(space.get_element_size(), Eq(expected_size));
EXPECT_THAT(space.get_element_space_size(), Eq(expected_size));
}
}
TEST(TensorDescriptor, EmptyExtent)
{
// A rank-0 tensor points to a single element
const auto desc = ckt::make_descriptor<ckb::DataType::FP16>(ckt::Extent{}, ckt::Extent{});
EXPECT_THAT(decltype(desc)::rank, Eq(0));
EXPECT_THAT(desc.get_lengths().size(), Eq(0));
EXPECT_THAT(desc.get_strides().size(), Eq(0));
EXPECT_THAT(desc.get_element_size(), Eq(1));
EXPECT_THAT(desc.get_element_space_size(), Eq(1));
EXPECT_THAT(desc.get_element_space_size_in_bytes(), Eq(2));
// We expect a rank-1 tensor with the one dimension being 1.
const auto space = desc.get_space_descriptor();
const auto expected = 1;
EXPECT_THAT(decltype(space)::rank, Eq(1));
EXPECT_THAT(space.get_lengths(), ElementsAreArray({expected}));
EXPECT_THAT(space.get_strides(), ElementsAreArray({1}));
EXPECT_THAT(space.get_element_size(), Eq(expected));
EXPECT_THAT(space.get_element_space_size(), Eq(expected));
EXPECT_THAT(space.get_element_space_size_in_bytes(), Eq(2));
}
TEST(TensorDescriptor, ExtentFromVector)
{
EXPECT_THAT(ckt::Extent<4>::from_vector(std::vector<size_t>{1, 2, 3, 4}),
ElementsAreArray({1, 2, 3, 4}));
EXPECT_THAT([] { return ckt::Extent<5>::from_vector(std::vector<size_t>{1, 2}); },
Throws<std::runtime_error>());
}
TEST(TensorDescriptor, IsPacked)
{
constexpr auto dt = ckb::DataType::INT32; // Irrelevant for this test
EXPECT_TRUE(
ckt::make_descriptor<dt>(ckt::Extent{101, 43, 25, 662, 654}, ckt::PackedLeftLayout{})
.is_packed());
EXPECT_TRUE(
ckt::make_descriptor<dt>(ckt::Extent{5334, 235, 1563, 256, 23}, ckt::PackedRightLayout{})
.is_packed());
EXPECT_TRUE(ckt::make_descriptor<dt>(ckt::Extent{}, ckt::Extent{}).is_packed());
EXPECT_TRUE(
ckt::make_descriptor<dt>(ckt::Extent{461, 345, 5, 93}, ckt::Extent{160425, 5, 1, 1725})
.is_packed());
EXPECT_FALSE(
ckt::make_descriptor<dt>(ckt::Extent{10, 11, 12}, ckt::Extent{1, 100, 1100}).is_packed());
EXPECT_FALSE(
ckt::make_descriptor<dt>(ckt::Extent{30, 20, 10}, ckt::Extent{1, 1, 1}).is_packed());
}

View File

@@ -0,0 +1,205 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/builder/testing/tensor_foreach.hpp"
#include "testing_utils.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <algorithm>
#include <functional>
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
using ::testing::Each;
using ::testing::Eq;
TEST(TensorForeach, CalculateOffset)
{
EXPECT_THAT(ckt::calculate_offset(ckt::Extent{1, 2, 3}, ckt::Extent{100, 10, 1}), Eq(123));
EXPECT_THAT(ckt::calculate_offset(ckt::Extent{523, 266, 263}, ckt::Extent{1, 545, 10532}),
Eq(2915409));
EXPECT_THAT(ckt::calculate_offset(ckt::Extent{}, ckt::Extent{}), Eq(0));
// Note: >4 GB overflow test
EXPECT_THAT(ckt::calculate_offset(ckt::Extent{8, 2, 5, 7, 0, 4, 1, 3, 6, 9},
ckt::Extent{1'000,
1'000'000,
10'000'000,
1'000'000'000,
1,
10'000,
100,
10,
100'000'000,
100'000}),
Eq(size_t{7'652'948'130}));
}
TEST(TensorForeach, VisitsCorrectCount)
{
// tensor_foreach should visit every index exactly once.
// This test checks that the count is at least correct.
const ckt::Extent shape = {10, 20, 30};
auto d_count = ckt::alloc_buffer(sizeof(uint64_t));
ckt::check_hip(hipMemset(d_count.get(), 0, sizeof(uint64_t)));
ckt::tensor_foreach(shape, [count = d_count.get()]([[maybe_unused]] const auto& index) {
atomicAdd(reinterpret_cast<uint64_t*>(count), 1);
});
uint64_t actual;
ckt::check_hip(hipMemcpy(&actual, d_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost));
const auto expected = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
EXPECT_THAT(actual, Eq(expected));
}
TEST(TensorForeach, VisitsEveryIndex)
{
const ckt::Extent shape = {5, 6, 7, 8, 9, 10, 11};
const auto total = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
// We know this is correct due to testing in unit_tensor_descriptor.cpp
const auto stride = ckt::PackedRightLayout{}(shape);
auto d_output = ckt::alloc_buffer(sizeof(uint32_t) * total);
ckt::check_hip(hipMemset(d_output.get(), 0, sizeof(uint32_t) * total));
ckt::tensor_foreach(shape, [output = d_output.get(), stride](const auto& index) {
// We know this is correct due to the CalculateOffset test.
auto offset = ckt::calculate_offset(index, stride);
// Use atomic add so that we can check that every index is visited exactly once.
atomicAdd(&reinterpret_cast<uint32_t*>(output)[offset], 1);
});
std::vector<uint32_t> actual(total);
ckt::check_hip(
hipMemcpy(actual.data(), d_output.get(), sizeof(uint32_t) * total, hipMemcpyDeviceToHost));
EXPECT_THAT(actual, Each(Eq(1)));
}
TEST(TensorForeach, FillTensorBuffer)
{
auto desc = ckt::make_descriptor<ckb::DataType::INT32>(ckt::Extent{31, 54, 13},
ckt::PackedRightLayout{});
auto buffer = ckt::alloc_tensor_buffer(desc);
ckt::fill_tensor_buffer(desc, buffer.get(), [](size_t i) { return static_cast<uint32_t>(i); });
std::vector<uint32_t> h_buffer(desc.get_element_space_size());
ckt::check_hip(hipMemcpy(
h_buffer.data(), buffer.get(), h_buffer.size() * sizeof(uint32_t), hipMemcpyDeviceToHost));
for(size_t i = 0; i < h_buffer.size(); ++i)
{
EXPECT_THAT(h_buffer[i], Eq(static_cast<uint32_t>(i)));
}
}
TEST(TensorForeach, FillTensor)
{
// FillTensor with non-packed indices should not write out-of-bounds.
const ckt::Extent shape = {4, 23, 35};
const ckt::Extent pad = {12, 53, 100};
auto desc = ckt::make_descriptor<ckb::DataType::INT32>(shape, ckt::PackedRightLayout{}(pad));
const auto strides = desc.get_strides();
auto size = desc.get_element_space_size();
auto buffer = ckt::alloc_tensor_buffer(desc);
ckt::fill_tensor_buffer(desc, buffer.get(), []([[maybe_unused]] size_t i) { return 123; });
ckt::fill_tensor(desc, buffer.get(), []([[maybe_unused]] const auto& index) { return 1; });
auto d_error = ckt::alloc_buffer(sizeof(uint32_t) * size);
ckt::check_hip(hipMemset(d_error.get(), 0, sizeof(uint32_t)));
ckt::tensor_foreach(
// Iterate over the entire padding so that we can check out-of-bounds elements
pad,
[shape, pad, strides, size, error = d_error.get(), tensor = buffer.get()](
const auto& index) {
const auto offset = ckt::calculate_offset(index, strides);
const auto value = reinterpret_cast<const uint32_t*>(tensor)[offset];
// Note: The space of the descriptor will not actually be (12, 53, 100) but
// more like (4, 53, 100), as the outer stride is irrelevant. So we have to
// perform an extra bounds check here.
if(offset < size)
{
// Check if the coordinate is within the shape bounds.
bool in_bounds = true;
for(size_t i = 0; i < shape.size(); ++i)
{
if(index[i] >= shape[i])
{
in_bounds = false;
}
}
// In-bounds elements are 1, out-of-bounds is 123.
if(in_bounds && value != 1)
{
atomicAdd(reinterpret_cast<uint32_t*>(error), 1);
}
else if(!in_bounds && value != 123)
{
atomicAdd(reinterpret_cast<uint32_t*>(error), 1);
}
}
});
uint32_t error_count = 0;
ckt::check_hip(hipMemcpy(&error_count, d_error.get(), sizeof(uint32_t), hipMemcpyDeviceToHost));
EXPECT_THAT(error_count, Eq(0));
}
TEST(TensorForeach, ClearTensorZeros)
{
const ckt::Extent shape = {5, 4, 5, 4, 5, 4, 5, 6};
const ckt::Extent pad = {6, 6, 6, 6, 6, 6, 6, 6};
const auto desc =
ckt::make_descriptor<ckb::DataType::INT32>(shape, ckt::PackedRightLayout{}(pad));
auto buffer = ckt::alloc_tensor_buffer(desc);
ckt::clear_tensor_buffer(desc, buffer.get());
// Check that all values are zeroed.
auto d_count = ckt::alloc_buffer(sizeof(uint64_t));
ckt::check_hip(hipMemset(d_count.get(), 0, sizeof(uint64_t)));
{
const auto size = desc.get_element_space_size();
const auto strides = desc.get_strides();
auto* count = d_count.get();
const auto* tensor = reinterpret_cast<const uint32_t*>(buffer.get());
// Note: iterate over the entire pad, so that we can check out-of-bounds elements.
ckt::tensor_foreach(pad,
[count, tensor, strides, size]([[maybe_unused]] const auto& index) {
const auto offset = ckt::calculate_offset(index, strides);
// Note: The space of the descriptor will not actually be (6, 6,
// ...) but more like (5, 6, ...), as the outer stride is
// irrelevant. So we have to perform an extra bounds check here.
if(offset < size && tensor[offset] != 0)
{
atomicAdd(reinterpret_cast<uint64_t*>(count), 1);
}
});
}
uint64_t actual;
ckt::check_hip(hipMemcpy(&actual, d_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost));
EXPECT_THAT(actual, Eq(0));
}

View File

@@ -0,0 +1,298 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/builder/testing/error.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/testing/validation.hpp"
#include "ck_tile/builder/testing/tensor_foreach.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/testing/testing.hpp"
#include "testing_utils.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <span>
#include <array>
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
using testing::ElementsAreArray;
using testing::Eq;
using testing::StrEq;
using ck_tile::test::MatchesReference;
using ck_tile::test::StringEqWithDiff;
// Googletest cannot have both type AND value parameterized tests.
// For now just act lazy and use value template parameters.
template <ckb::DataType DT, ckt::Extent SHAPE, auto STRIDES>
struct Param
{
constexpr static auto data_type = DT;
constexpr static auto shape = SHAPE;
constexpr static auto strides = STRIDES;
constexpr static auto rank = shape.size();
static ckt::TensorDescriptor<data_type, rank> get_descriptor()
{
return ckt::make_descriptor<data_type, rank>(shape, strides);
}
};
template <typename Param>
struct ValidationReportTests : public ::testing::Test
{
};
using Types = ::testing::Types<
Param<ckb::DataType::FP32, ckt::Extent{52, 152, 224}, ckt::PackedRightLayout{}>,
Param<ckb::DataType::FP32, ckt::Extent{72, 1, 49, 2, 4, 5}, ckt::PackedLeftLayout{}>,
Param<ckb::DataType::FP32, ckt::Extent{}, ckt::Extent{}>,
Param<ckb::DataType::FP32, ckt::Extent{12, 34, 43, 21}, ckt::Extent{41, 1, 43210, 1831}>>;
TYPED_TEST_SUITE(ValidationReportTests, Types);
TYPED_TEST(ValidationReportTests, SingleCorrect)
{
const auto desc = TypeParam::get_descriptor();
auto a = ckt::alloc_tensor_buffer(desc);
auto b = ckt::alloc_tensor_buffer(desc);
ckt::clear_tensor_buffer(desc, a.get());
ckt::clear_tensor_buffer(desc, b.get());
// Generate a sort-of-random looking sequence
auto generator = [strides = desc.get_strides()](const auto& index) {
const auto flat_index = ckt::calculate_offset(index, strides);
return static_cast<float>((flat_index + 1) * 10'000'019 % 768'351);
};
ckt::fill_tensor(desc, a.get(), generator);
ckt::fill_tensor(desc, b.get(), generator);
ckt::ValidationReport report;
report.check("correct", desc, b.get(), a.get());
EXPECT_THAT(report.get_errors().size(), Eq(0));
}
TYPED_TEST(ValidationReportTests, SingleIncorrect)
{
const auto desc = TypeParam::get_descriptor();
const auto packed_strides = ckt::PackedRightLayout{}(desc.get_lengths());
auto a = ckt::alloc_tensor_buffer(desc);
auto b = ckt::alloc_tensor_buffer(desc);
ckt::clear_tensor_buffer(desc, a.get());
ckt::clear_tensor_buffer(desc, b.get());
ckt::fill_tensor(desc, a.get(), []([[maybe_unused]] const auto& i) { return 123; });
ckt::fill_tensor(desc, b.get(), [packed_strides](const auto& index) {
const auto flat_index = ckt::calculate_offset(index, packed_strides);
return flat_index == 0 ? 0 : flat_index == 12345 ? 456 : flat_index == 999999 ? 1 : 123;
});
ckt::ValidationReport report;
report.check("incorrect", desc, b.get(), a.get());
const auto errors = report.get_errors();
const auto flat_size = desc.get_element_size();
const auto expected_errors = flat_size >= 999999 ? 3 : flat_size >= 12345 ? 2 : 1;
ASSERT_THAT(errors.size(), Eq(1));
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect"));
EXPECT_THAT(errors[0].wrong_elements, Eq(expected_errors));
EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size()));
}
TYPED_TEST(ValidationReportTests, ZeroIsIncorrect)
{
const auto desc = TypeParam::get_descriptor();
auto a = ckt::alloc_tensor_buffer(desc);
auto b = ckt::alloc_tensor_buffer(desc);
ckt::clear_tensor_buffer(desc, a.get());
ckt::clear_tensor_buffer(desc, b.get());
ckt::ValidationReport report;
report.check("zero_is_incorrect", desc, b.get(), a.get());
const auto errors = report.get_errors();
ASSERT_THAT(errors.size(), Eq(1));
EXPECT_THAT(errors[0].tensor_name, StrEq("zero_is_incorrect"));
EXPECT_THAT(errors[0].wrong_elements, Eq(0));
EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size()));
EXPECT_THAT(errors[0].zero_elements, Eq(desc.get_element_size()));
}
TEST(ValidationReportTests, MultipleSomeIncorrect)
{
ckt::ValidationReport report;
{
auto desc = ckt::make_descriptor<ckb::DataType::BF16, 4>({'R', 'O', 'C', 'm'},
ckt::PackedLeftLayout{});
auto a = ckt::alloc_tensor_buffer(desc);
auto b = ckt::alloc_tensor_buffer(desc);
ckt::fill_tensor_buffer(
desc, a.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(i % 100); });
ckt::fill_tensor_buffer(
desc, b.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(i % 101); });
report.check("incorrect 1", desc, b.get(), a.get());
}
{
auto desc =
ckt::make_descriptor<ckb::DataType::U8, 3>({'H', 'I', 'P'}, ckt::PackedRightLayout{});
auto a = ckt::alloc_tensor_buffer(desc);
auto b = ckt::alloc_tensor_buffer(desc);
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return "ROCm"[i % 4]; });
ckt::fill_tensor_buffer(desc, b.get(), [](size_t i) {
switch(i % 4)
{
case 0: return 'R';
case 1: return 'O';
case 2: return 'C';
case 3: return 'm';
default: return 'x';
}
});
report.check("correct", desc, b.get(), a.get());
}
{
auto desc = ckt::make_descriptor<ckb::DataType::INT32, 3>({'G', 'P', 'U'},
ckt::PackedRightLayout{});
auto a = ckt::alloc_tensor_buffer(desc);
auto b = ckt::alloc_tensor_buffer(desc);
ckt::fill_tensor_buffer(desc, a.get(), []([[maybe_unused]] size_t i) { return 1; });
ckt::fill_tensor_buffer(desc, b.get(), []([[maybe_unused]] size_t i) { return 555; });
report.check("incorrect 2", desc, b.get(), a.get());
}
const auto errors = report.get_errors();
ASSERT_THAT(errors.size(), Eq(2));
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect 1"));
EXPECT_THAT(errors[0].wrong_elements, Eq(46840334));
EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect 2"));
EXPECT_THAT(errors[1].wrong_elements, Eq(482800));
}
// MatchesReference operates on the types defined in testing.hpp, so just
// quickly define a bunch of dummy values for that.
struct DummySignature
{
};
constexpr DummySignature DUMMY_SIGNATURE = {};
namespace ck_tile::builder::test {
template <>
struct Args<DUMMY_SIGNATURE>
{
auto make_a_descriptor() const
{
return make_descriptor<builder::DataType::FP32>(Extent{5, 5, 5, 5}, PackedRightLayout{});
}
auto make_b_descriptor() const
{
return make_descriptor<builder::DataType::FP16>(Extent{100000}, PackedLeftLayout{});
}
};
template <>
struct Outputs<DUMMY_SIGNATURE>
{
void* a;
void* b;
};
template <>
ValidationReport validate<DUMMY_SIGNATURE>(const Args<DUMMY_SIGNATURE>& args,
Outputs<DUMMY_SIGNATURE> actual,
Outputs<DUMMY_SIGNATURE> expected)
{
ValidationReport report;
report.check("a", args.make_a_descriptor(), actual.a, expected.a);
report.check("b", args.make_b_descriptor(), actual.b, expected.b);
return report;
}
} // namespace ck_tile::builder::test
TEST(MatchesReference, Correct)
{
const ckt::Args<DUMMY_SIGNATURE> args;
const auto a_desc = args.make_a_descriptor();
const auto b_desc = args.make_b_descriptor();
auto a_actual = ckt::alloc_tensor_buffer(a_desc);
auto b_actual = ckt::alloc_tensor_buffer(b_desc);
ckt::clear_tensor_buffer(a_desc, a_actual.get(), 1);
ckt::clear_tensor_buffer(b_desc, b_actual.get(), 2);
const auto actual = ckt::Outputs<DUMMY_SIGNATURE>{
.a = a_actual.get(),
.b = b_actual.get(),
};
auto a_expected = ckt::alloc_tensor_buffer(a_desc);
auto b_expected = ckt::alloc_tensor_buffer(b_desc);
ckt::clear_tensor_buffer(a_desc, a_expected.get(), 1);
ckt::clear_tensor_buffer(b_desc, b_expected.get(), 2);
const auto expected = ckt::Outputs<DUMMY_SIGNATURE>{
.a = a_expected.get(),
.b = b_expected.get(),
};
EXPECT_THAT(actual, MatchesReference(args, expected));
}
TEST(MatchesReference, Incorrect)
{
const ckt::Args<DUMMY_SIGNATURE> args;
const auto a_desc = args.make_a_descriptor();
const auto b_desc = args.make_b_descriptor();
auto a_actual = ckt::alloc_tensor_buffer(a_desc);
auto b_actual = ckt::alloc_tensor_buffer(b_desc);
ckt::clear_tensor_buffer(a_desc, a_actual.get(), 1);
ckt::clear_tensor_buffer(b_desc, b_actual.get(), 2);
const auto actual = ckt::Outputs<DUMMY_SIGNATURE>{
.a = a_actual.get(),
.b = b_actual.get(),
};
auto a_expected = ckt::alloc_tensor_buffer(a_desc);
auto b_expected = ckt::alloc_tensor_buffer(b_desc);
ckt::clear_tensor_buffer(a_desc, a_expected.get(), 2);
ckt::clear_tensor_buffer(b_desc, b_expected.get(), 2);
const auto expected = ckt::Outputs<DUMMY_SIGNATURE>{
.a = a_expected.get(),
.b = b_expected.get(),
};
testing::StringMatchResultListener listener;
EXPECT_TRUE(!ExplainMatchResult(MatchesReference(args, expected), actual, &listener));
EXPECT_THAT(listener.str(), StringEqWithDiff("1 tensors failed to validate"));
}

View File

@@ -399,7 +399,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuf
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WarpGemm_>(t))
<< "," << to_string(static_cast<InputOutputTileTransfer_<4>>(t));
<< "," << to_string(static_cast<InputOutputTileTransfer_<>>(t));
return oss.str();
}

View File

@@ -7,7 +7,6 @@
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/device_tensor_generator.hpp"
#include "ck/utility/data_type.hpp"
#include <cmath>
// use xorshift for now since it is simple. Should be suitable enough, but feel free to switch in
// the future
@@ -107,6 +106,7 @@ template <typename T>
__global__ void
fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_element_size)
{
static constexpr float PI = 3.141592653f;
// initial values
ran_state_u32 s = ran_init();
float norm[2];
@@ -115,12 +115,11 @@ fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_e
{
if(j % (2 / ck::packed_size_v<T>) == 0)
{
float u1 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f);
float u2 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f);
norm[0] =
sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::cos(2.0f * M_PI * u2) + mean;
norm[1] =
sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::sin(2.0f * M_PI * u2) + mean;
float u1 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f);
float u2 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f);
float scale = sigma * ck::math::sqrt(-2.0f * ck::math::log(u1));
norm[0] = scale * ck::math::cos(2.0f * PI * u2) + mean;
norm[1] = scale * ck::math::sin(2.0f * PI * u2) + mean;
}
if constexpr(ck::is_same_v<T, ck::f4x2_pk_t>)

View File

@@ -10,7 +10,8 @@
namespace ck {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
defined(__gfx1103__) || defined(__gfx11_generic__)
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__)
#define __gfx11__
#endif

View File

@@ -2376,12 +2376,23 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
thread_buffer<T, N> tmp =
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
if constexpr(oob_conditional_check)
return src_thread_element_valid ? tmp : thread_buffer<T, N>{numeric<T>::zero()};
{
if(src_thread_element_valid)
{
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
}
else
{
return thread_buffer<T, N>{numeric<T>::zero()};
}
}
else
return tmp;
{
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
}
#endif
}

View File

@@ -87,6 +87,7 @@ enum struct amdgcn_target_id
GFX1150 = 0x1150,
GFX1151 = 0x1151,
GFX1152 = 0x1152,
GFX1153 = 0x1153,
GFX11_GENERIC = 0x11FF,
GFX1200 = 0x1200,
GFX1201 = 0x1201,
@@ -282,6 +283,7 @@ constexpr auto get_compiler_target()
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1153, GFX1153);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC);
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200);
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201);
@@ -348,6 +350,7 @@ CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target_id(char const*
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1150", GFX1150);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1151", GFX1151);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1152", GFX1152);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1153", GFX1153);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx11_generic", GFX11_GENERIC);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1200", GFX1200);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1201", GFX1201);
@@ -603,6 +606,7 @@ CK_TILE_HOST_DEVICE constexpr auto get_compiler_target()
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1153, GFX1153);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC);
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200);
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201);
@@ -683,6 +687,7 @@ CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target(char const* tes
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1150", GFX1150);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1151", GFX1151);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1152", GFX1152);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1153", GFX1153);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx11_generic", GFX11_GENERIC);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1200", GFX1200);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1201", GFX1201);
@@ -1119,8 +1124,14 @@ CK_TILE_DEVICE static constexpr auto get_device_arch()
{
// FIXME(0): on all devices except gfx11 it returns gfx12_t
// FIXME(1): during the host compilation pass it returns gfx12_t
#if defined(__gfx11__)
#if defined(__gfx103__)
return gfx103_t{};
#elif defined(__gfx11__)
return gfx11_t{};
#elif defined(__gfx950__)
return gfx950_t{};
#elif defined(__gfx9__)
return gfx9_t{};
#else
return gfx12_t{};
#endif
@@ -1141,26 +1152,10 @@ CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; }
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx_invalid_t) { return 0; }
CK_TILE_DEVICE static constexpr auto arch_tag_dispatch()
{
#if defined(__gfx103__)
return gfx103_t{};
#elif defined(__gfx11__)
return gfx11_t{};
#elif defined(__gfx12__)
return gfx12_t{};
#elif defined(__gfx950__)
return gfx950_t{};
#elif defined(__gfx9__)
return gfx9_t{};
#else
return gfx_invalid_t{};
#endif
}
} // namespace detail
CK_TILE_DEVICE static constexpr auto get_n_lds_banks()
{
return detail::get_n_lds_banks(detail::arch_tag_dispatch());
return detail::get_n_lds_banks(get_device_arch());
}
enum LLVMSchedGroupMask : int32_t

View File

@@ -315,6 +315,7 @@ namespace ck_tile::core {
* @var CK_TILE_ARCH_GFX1102 Indicates if the compiler target architecture is GFX1102.
* @var CK_TILE_ARCH_GFX1151 Indicates if the compiler target architecture is GFX1151.
* @var CK_TILE_ARCH_GFX1152 Indicates if the compiler target architecture is GFX1152.
* @var CK_TILE_ARCH_GFX1153 Indicates if the compiler target architecture is GFX1153.
* @var CK_TILE_ARCH_GFX11_GENERIC Indicates if the compiler target architecture is GFX11 generic.
* @var CK_TILE_ARCH_GFX1200 Indicates if the compiler target architecture is GFX1200.
* @var CK_TILE_ARCH_GFX1201 Indicates if the compiler target architecture is GFX1201.
@@ -468,6 +469,12 @@ struct amdgcn_compiler_target_state
static constexpr bool CK_TILE_ARCH_GFX1152 = false;
#endif // __gfx1152__
#if defined(__gfx1153__)
static constexpr bool CK_TILE_ARCH_GFX1153 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1153 = false;
#endif // __gfx1153__
#if defined(__gfx11_generic__)
static constexpr bool CK_TILE_ARCH_GFX11_GENERIC = true;
#else
@@ -538,6 +545,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1150, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1151, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1152, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1153, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX11_GENERIC, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1200, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1201, \

View File

@@ -34,46 +34,23 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
// y_dim_out_to_in
// For swapped Hs tile case I need only get_rh_minor_to_y
// since rh_major are already swapped due to swapped Hs.
constexpr auto get_rh_minor_to_y = [](auto dstr_tensor) {
using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
map<index_t, index_t> rh_minor_to_y_;
static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
rh_minor_to_y_(rh_minor) = i;
});
return rh_minor_to_y_;
};
// In swapped Hs case <Y,X> -> <X,Y> tile
// we have same rh_major, but reversed rh_minor!
constexpr auto rh_minor_to_y_in = get_rh_minor_to_y(InTensor{});
constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{});
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
// Is this really needed?? Should we have simple reverse here??
constexpr auto y_dim_out_to_in = [&] {
map<index_t, index_t> y_dim_out_to_in_;
for(const auto& [rh_minor, y_out] : rh_minor_to_y_out)
{
y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor];
}
static_for<0, NDimY, 1>{}([&](auto i) { y_dim_out_to_in_(i) = NDimY - 1 - i; });
return y_dim_out_to_in_;
}();
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
// input and output vector dim in the order of input Y dims
constexpr index_t y_dim_vec_in = NDimY - 1;
constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
constexpr index_t y_dim_vec_out = 0;
// vector lengths
constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];

View File

@@ -55,9 +55,10 @@ struct FillUniformDistribution
const auto total_bytes = total * sizeof(T_iter);
// max 80 threads; at least 2MB per thread
const size_t available_cpu_cores = get_available_cpu_cores();
const size_t num_thread =
min(80UL, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL));
const size_t available_cpu_cores = get_available_cpu_cores();
constexpr uint64_t MAX_THREAD_COUNT = 80;
const size_t num_thread = min(
MAX_THREAD_COUNT, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL));
constexpr size_t BLOCK_BYTES = 64;
constexpr size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T_iter);
const size_t num_blocks = integer_divide_ceil(total_bytes, BLOCK_BYTES);

View File

@@ -3,6 +3,7 @@
#pragma once
#include <cinttypes>
#include <cstdlib>
#include <thread>
@@ -28,7 +29,7 @@ CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor<InDataType>& input,
output.get_num_of_dimension() == NDimSpatial + 3))
{
printf("%lu %lu %lu",
printf("%" PRIu64 " %" PRIu64 " %" PRIu64,
input.get_num_of_dimension(),
weight.get_num_of_dimension(),
output.get_num_of_dimension());

View File

@@ -30,7 +30,6 @@ template <typename AsDataType_,
index_t NPerXdl_,
index_t KPerXdl_,
bool isCTransposed_,
memory_operation_enum MemoryOperation_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
@@ -39,31 +38,30 @@ template <typename AsDataType_,
bool DoubleSmemBuffer_ = false>
struct CShuffleEpilogueProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
@@ -105,28 +103,27 @@ struct CShuffleEpilogue
ADataType,
BDataType>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t MWave = Problem::MWave;
static constexpr index_t NWave = Problem::NWave;
static constexpr index_t MPerXdl = Problem::MPerXdl;
static constexpr index_t NPerXdl = Problem::NPerXdl;
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t MWave = Problem::MWave;
static constexpr index_t NWave = Problem::NWave;
static constexpr index_t MPerXdl = Problem::MPerXdl;
static constexpr index_t NPerXdl = Problem::NPerXdl;
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
CDElementwise elfunc_;
@@ -142,8 +139,7 @@ struct CShuffleEpilogue
concat('x', MWave, NWave),
concat('x', MPerXdl, NPerXdl, KPerXdl),
VectorSizeC,
isCTransposed ? "CTransposed" : "CNotTransposed",
mem_op_string<MemoryOperation>());
isCTransposed ? "CTransposed" : "CNotTransposed");
// clang-format on
}
@@ -337,14 +333,30 @@ struct CShuffleEpilogue
{
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
// this branch is for original a16w4
if constexpr(is_any_of<ADataType, pk_int4_t, pk_fp4_t>::value ||
is_any_of<BDataType, pk_int4_t, pk_fp4_t>::value)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, BlockedXDLN_PerWarp, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 1>>{};
}
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
@@ -355,7 +367,8 @@ struct CShuffleEpilogue
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
return lds_block_desc.get_element_space_size() * sizeof(ODataType);
}
template <index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
@@ -445,7 +458,8 @@ struct CShuffleEpilogue
CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
const COutTensor& c_out_tensor)
{
if constexpr(MemoryOperation == memory_operation_enum::set)
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
@@ -617,7 +631,8 @@ struct CShuffleEpilogue
});
// store/update
if constexpr(MemoryOperation == memory_operation_enum::set)
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}

View File

@@ -15,17 +15,15 @@ template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool UseRawStore_ = true,
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
bool UseRawStore_ = true>
struct Default2DEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr index_t NumDTensor = 0;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_;
static constexpr index_t NumDTensor = 0;
};
template <typename AsDataType_,
@@ -44,14 +42,9 @@ template <typename AsDataType_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_,
bool UseRawStore_ = true,
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataType_,
ODataType_,
kPadM_,
kPadN_,
UseRawStore_,
MemoryOperation_>
bool UseRawStore_ = true>
struct DefaultGemm2DEpilogueProblem
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
{
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
@@ -81,7 +74,6 @@ struct Default2DEpilogue
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseRawStore = Problem::UseRawStore;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
@@ -102,7 +94,10 @@ struct Default2DEpilogue
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
{
if constexpr(MemoryOperation == memory_operation_enum::set)
// FIXME?
// if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp ==
// memory_operation_enum::set)
if constexpr(true)
{
if constexpr(is_partition_index)
{
@@ -123,7 +118,10 @@ struct Default2DEpilogue
}
else
{
if constexpr(MemoryOperation == memory_operation_enum::set)
// FIXME?
// if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp ==
// memory_operation_enum::set)
if constexpr(true)
{
if constexpr(is_partition_index)
{

View File

@@ -558,21 +558,19 @@ struct FlatmmKernel
return DTesnorIsValid;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t block_idx_m)
{
// Step 1: Create tensor view
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.M, k_size),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
@@ -581,25 +579,81 @@ struct FlatmmKernel
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(k_size, kargs.M),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
index_t kFlatK =
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
// Step 2: Create padded view
const auto& a_pad_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
}();
// Step 3: Create tile window
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{block_idx_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, block_idx_m});
}
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
const KernelArgs& kargs,
const index_t block_idx_n)
{
// Step 1: Create tensor view
index_t kFlatK =
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
// Step 2: No padding needed for b_flat
// Step 3: Create tile window
return make_tile_window(
b_flat_tensor_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const KernelArgs& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Step 1: Create tensor views
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
@@ -625,7 +679,56 @@ struct FlatmmKernel
},
number<NumDTensor>{});
// TODO: enable vector write for C in ColMajor
// Step 2: Create padded views
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// Step 3: Create tile windows
return generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{block_idx_m, block_idx_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{block_idx_n, block_idx_m});
}
},
number<NumDTensor>{});
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
const KernelArgs& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Step 1: Create tensor view
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
@@ -647,98 +750,8 @@ struct FlatmmKernel
}
}();
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
: 1; // per-token scale
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
: 1; // per-channel scale
static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
"only support per-tensor or per-row scaling");
static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
"only support per-tensor or per-column scaling");
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m_ptr.ptr,
make_tuple(kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0
? 1
: splitk_batch_offset.splitted_k /
(ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)),
make_tuple(scale_stride_m, 0),
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
number<1>{});
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_n_ptr.ptr,
make_tuple(ScaleGranularityKB == 0
? 1
: (splitk_batch_offset.splitted_k /
(ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)),
kargs.N / ScaleGranularityN),
make_tuple(0, scale_stride_n),
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
number<1>{});
return make_tuple(a_tensor_view,
b_flat_tensor_view,
ds_tensor_view,
e_tensor_view,
scale_m_view,
scale_n_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
}();
const auto& b_flat_tensor_view = views.at(I1);
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
const auto& d_tensor_view = views.at(I2);
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// TODO vector write in for C in ColMajor
// Step 2: Create padded view
const auto& e_pad_view = [&]() {
const auto& e_tensor_view = views.at(I3);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(e_tensor_view,
@@ -755,93 +768,72 @@ struct FlatmmKernel
}
}();
return make_tuple(a_pad_view,
b_flat_tensor_view,
ds_pad_view,
e_pad_view,
views.at(number<4>{}),
views.at(number<5>{}));
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& b_flat_pad_view = views.at(I1);
const auto& ds_pad_view = views.at(I2);
const auto& e_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
const auto ds_block_window = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{i_n, i_m});
}
},
number<NumDTensor>{});
auto e_block_window = make_tile_window(
// Step 3: Create tile window
return make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
{block_idx_m, block_idx_n});
}
constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK;
constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK;
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeScaleMWindow(const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m)
{
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
auto scale_m_window = make_tile_window(views.at(number<4>{}),
make_tuple(number<TilePartitioner::MPerBlock>{},
number < ScaleGranularityKA == 0
? TilePartitioner::NPerBlock
: TilePartitioner::KPerBlock > {}),
{i_m, 0});
auto scale_n_window = make_tile_window(views.at(number<5>{}),
make_tuple(number < ScaleGranularityKB == 0
? TilePartitioner::MPerBlock
: TilePartitioner::KPerBlock > {},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
: 1; // per-token scale
return make_tuple(a_block_window,
b_flat_block_window,
ds_block_window,
e_block_window,
scale_m_window,
scale_n_window);
// Step 1: Create tensor view
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m_ptr.ptr,
make_tuple(kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0
? 1
: (splitk_batch_offset.splitted_k / ScaleGranularityKA)),
make_tuple(scale_stride_m, 0),
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
number<1>{});
// Step 2: Create tile window
return make_tile_window(scale_m_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number < ScaleGranularityKA == 0
? TilePartitioner::NPerBlock
: TilePartitioner::KPerBlock > {}),
{block_idx_m, 0});
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeScaleNWindow(const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_n)
{
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
: 1; // per-channel scale
// Step 1: Create tensor view
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_n_ptr.ptr,
make_tuple(
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
kargs.N / ScaleGranularityN),
make_tuple(0, scale_stride_n),
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
number<1>{});
// Step 2: Create tile window
return make_tile_window(scale_n_view,
make_tuple(number < ScaleGranularityKB == 0
? TilePartitioner::MPerBlock
: TilePartitioner::KPerBlock > {},
number<TilePartitioner::NPerBlock>{}),
{0, block_idx_n});
}
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
@@ -857,45 +849,74 @@ struct FlatmmKernel
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Create block windows using specialized methods
const auto& a_block_window =
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m);
const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
auto scale_m_window = gemm_tile_windows.at(number<4>{});
auto scale_n_window = gemm_tile_windows.at(number<5>{});
// Run Epilogue Pipeline
// Run Epilogue Pipeline with k_batch dispatching
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
scale_m_window,
scale_n_window);
if(kargs.k_batch == 1)
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(e_block_window),
decltype(c_block_tile),
decltype(ds_block_window)>(e_block_window,
c_block_tile,
ds_block_window,
smem_ptr_ping,
scale_m_window,
scale_n_window);
}
else
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(e_block_window),
decltype(c_block_tile),
decltype(ds_block_window)>(e_block_window,
c_block_tile,
ds_block_window,
smem_ptr_ping,
scale_m_window,
scale_n_window);
}
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
if(kargs.k_batch == 1)
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(e_block_window),
decltype(c_block_tile),
decltype(ds_block_window)>(
e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
}
else
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(e_block_window),
decltype(c_block_tile),
decltype(ds_block_window)>(
e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
}
}
}
@@ -924,8 +945,7 @@ struct FlatmmKernel
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);

View File

@@ -100,21 +100,19 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t block_idx_m)
{
// Step 1: Create tensor view
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.M, k_size),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
@@ -123,25 +121,80 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(k_size, kargs.M),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
// Step 2: Create padded view
const auto& a_pad_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
}();
// Step 3: Create tile window
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{block_idx_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, block_idx_m});
}
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
const KernelArgs& kargs,
const index_t block_idx_n)
{
// Step 1: Create tensor view
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
}();
const auto& b_flat_tensor_view = make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
// Step 2: No padding needed for b_flat
// Step 3: Create tile window
return make_tile_window(
b_flat_tensor_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const KernelArgs& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Step 1: Create tensor views
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
@@ -167,7 +220,56 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
},
number<NumDTensor>{});
// TODO: enable vector write for C in ColMajor
// Step 2: Create padded views
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// Step 3: Create tile windows
return generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{block_idx_m, block_idx_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{block_idx_n, block_idx_m});
}
},
number<NumDTensor>{});
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
const KernelArgs& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Step 1: Create tensor view
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
@@ -189,70 +291,8 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
}
}();
auto scale_n = kargs.scale_n_ptr;
index_t FlatScaleK =
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const e8m0_t*>(scale_n.ptr),
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
return make_tuple(
a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
}();
const auto& b_flat_tensor_view = views.at(I1);
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
const auto& d_tensor_view = views.at(I2);
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// TODO vector write in for C in ColMajor
// Step 2: Create padded view
const auto& e_pad_view = [&]() {
const auto& e_tensor_view = views.at(I3);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(e_tensor_view,
@@ -269,77 +309,37 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
}
}();
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4));
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& b_flat_pad_view = views.at(I1);
const auto& ds_pad_view = views.at(I2);
const auto& e_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
const auto ds_block_window = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{i_n, i_m});
}
},
number<NumDTensor>{});
auto e_block_window = make_tile_window(
// Step 3: Create tile window
return make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
{block_idx_m, block_idx_n});
}
auto scale_block_window =
make_tile_window(views.at(I4),
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
{i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const KernelArgs& kargs,
const index_t block_idx_n)
{
auto scale_n = kargs.scale_n_ptr;
return make_tuple(a_block_window,
b_flat_block_window,
ds_block_window,
e_block_window,
scale_block_window);
// Step 1: Create tensor view
index_t FlatScaleK =
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const e8m0_t*>(scale_n.ptr),
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
// Step 2: Create tile window
return make_tile_window(
scale_b_flat_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
{block_idx_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
}
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
@@ -355,21 +355,15 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Create block windows using specialized methods
const auto& a_block_window =
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
const auto& scale_block_window = MakeScaleBBlockWindow(kargs, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& scale_block_window = gemm_tile_windows.at(I4);
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
@@ -378,6 +372,7 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
// Run GEMM cooperatively by whole workgroup.
auto a_block_window_with_distr =
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
a_block_window.get_window_lengths(),
@@ -390,22 +385,46 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
smem_ptr_ping,
smem_ptr_pong);
// Run Epilogue Pipeline
// Run Epilogue Pipeline with k_batch dispatching
if constexpr(DoEpiScale)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}(c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
if(kargs.k_batch == 1)
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(e_block_window,
c_block_tile,
ds_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
}
else
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(e_block_window,
c_block_tile,
ds_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
}
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
if(kargs.k_batch == 1)
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
}
else
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
}
}
}
@@ -434,8 +453,7 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
__shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
__shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);

View File

@@ -1476,7 +1476,8 @@ struct MoeFlatmmKernel
c_scatter_valids[mIter]);
if constexpr(!IsInputGemm ||
EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)
decltype(c_block_window.get_bottom_tensor_view())::DstInMemOp ==
memory_operation_enum::atomic_add)
c_scatter_tile_window.update(c_out_tensor);
else
c_scatter_tile_window.store(c_out_tensor);

View File

@@ -113,32 +113,50 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t block_idx_m)
{
// Step 1: Create tensor view
const auto& a_tensor_view = [&]() {
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
"A tensor for mx must be RowMajor");
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.M, k_size),
make_tuple(kargs.stride_A, 1),
number<MXFlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}();
// Step 2: Create padded view
const auto& a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<false, MXFlatmmPipeline::kPadK>{});
// Step 3: Create tile window
return make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{block_idx_m, 0});
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
const KernelArgs& kargs,
const index_t block_idx_n)
{
// Step 1: Create tensor view with special flat layout
constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock;
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
const index_t kFlatN = kargs.N / kNWarpTile;
const auto& b_flat_tensor_view = [&]() {
const auto& b_flat_tensor_view = [&]() {
static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0,
"wrong! vector size for B tensor");
auto&& naive_desc = make_naive_tensor_descriptor_packed(
@@ -153,6 +171,22 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
}();
// Step 2: No padding for flat B
// Step 3: Create tile window
return make_tile_window(
b_flat_tensor_view,
make_tuple(number<MXFlatmmPipeline::flatNPerWarp>{},
number<MXFlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const KernelArgs& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Step 1: Create tensor views
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
@@ -178,7 +212,56 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
},
number<NumDTensor>{});
// TODO: enable vector write for C in ColMajor
// Step 2: Create padded views
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, MXFlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, MXFlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// Step 3: Create tile windows
return generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{block_idx_m, block_idx_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{block_idx_n, block_idx_m});
}
},
number<NumDTensor>{});
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
const KernelArgs& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Step 1: Create tensor view
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
@@ -200,92 +283,8 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
}
}();
auto scale_a = kargs.scale_m_ptr;
auto scale_b = kargs.scale_n_ptr;
static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
// A scale tensor view
const auto& scale_a_tensor_view = [&]() {
// Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl));
const auto scale_a_desc = transform_tensor_descriptor(
scale_a_naive_desc,
make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)),
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
}();
// B scale tensor view
const auto& scale_b_tensor_view = [&]() {
const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
const auto scale_b_desc = transform_tensor_descriptor(
scale_b_navie_desc,
make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
}();
return make_tuple(a_tensor_view,
b_flat_tensor_view,
ds_tensor_view,
e_tensor_view,
scale_a_tensor_view,
scale_b_tensor_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
"A tensor for mx must be RowMajor");
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, MXFlatmmPipeline::kPadK>{});
}();
const auto& b_flat_tensor_view = views.at(I1);
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
const auto& d_tensor_view = views.at(I2);
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, MXFlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, MXFlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// TODO vector write in for C in ColMajor
// Step 2: Create padded view
const auto& e_pad_view = [&]() {
const auto& e_tensor_view = views.at(I3);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(e_tensor_view,
@@ -302,79 +301,71 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
}
}();
return make_tuple(
a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4), views.at(I5));
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& b_flat_pad_view = views.at(I1);
const auto& ds_pad_view = views.at(I2);
const auto& e_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
"A tensor for mx must be RowMajor");
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}();
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<MXFlatmmPipeline::flatNPerWarp>{},
number<MXFlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
const auto ds_block_window = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{i_n, i_m});
}
},
number<NumDTensor>{});
auto e_block_window = make_tile_window(
// Step 3: Create tile window
return make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
{block_idx_m, block_idx_n});
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeScaleABlockWindow(const KernelArgs& kargs,
const index_t block_idx_m)
{
static constexpr int BlockScaleSize = 32;
auto scale_a_block_window = make_tile_window(
views.at(I4),
const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
// Step 1: Create tensor view
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl));
const auto scale_a_desc = transform_tensor_descriptor(
scale_a_naive_desc,
make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)),
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto& scale_a_tensor_view = make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(kargs.scale_m_ptr.ptr), scale_a_desc);
// Step 2: Create tile window
return make_tile_window(
scale_a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock / MXdlPack>{},
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
{i_m / MXdlPack, 0});
{block_idx_m / MXdlPack, 0});
}
auto scale_b_block_window = make_tile_window(
views.at(I5),
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const KernelArgs& kargs,
const index_t block_idx_n)
{
static constexpr int BlockScaleSize = 32;
const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
// Step 1: Create tensor view
const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
const auto scale_b_desc = transform_tensor_descriptor(
scale_b_naive_desc,
make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto& scale_b_tensor_view = make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(kargs.scale_n_ptr.ptr), scale_b_desc);
// Step 2: Create tile window
return make_tile_window(
scale_b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock / NXdlPack>{},
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
{i_n / NXdlPack, 0});
return make_tuple(a_block_window,
b_flat_block_window,
ds_block_window,
e_block_window,
scale_a_block_window,
scale_b_block_window);
{block_idx_n / NXdlPack, 0});
}
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
@@ -390,22 +381,16 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Create block windows using specialized methods
const auto& a_block_window =
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
const auto& scale_a_block_window = MakeScaleABlockWindow(kargs, block_idx_m);
const auto& scale_b_block_window = MakeScaleBBlockWindow(kargs, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& scale_a_block_window = gemm_tile_windows.at(I4);
const auto& scale_b_block_window = gemm_tile_windows.at(I5);
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
@@ -422,22 +407,46 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
smem_ptr_ping,
smem_ptr_pong);
// Run Epilogue Pipeline
// Run Epilogue Pipeline with split_k dispatch
if constexpr(DoEpiScale)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}(c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
if(kargs.k_batch == 1)
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(e_block_window,
c_block_tile,
ds_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
}
else
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(e_block_window,
c_block_tile,
ds_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
}
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
if(kargs.k_batch == 1)
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
}
else
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
}
}
}
@@ -466,27 +475,17 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
__shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
__shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
static_assert(false,
"Unimplemented: atomic_add with odd vector size for fp16/bf16");
}
constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
partition_idx += gridDim.x;
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
}

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"

View File

@@ -0,0 +1,32 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile {
// KV cache memory layout selector.
//
// Layout summary (kVectorSize = 16 / sizeof(KDataType)):
// - VECTORIZED_LAYOUT (swizzled):
// K: [NumBlocks, NumHeads, HeadDim/kVectorSize, PageSize, kVectorSize]
// V: [NumBlocks, NumHeads, PageSize/kVectorSize, HeadDim, kVectorSize]
// - LINEAR_LAYOUT:
// K: [NumBlocks, PageSize, NumHeads, HeadDim]
// V: [NumBlocks, PageSize, NumHeads, HeadDim]
enum class BlockAttentionKVCacheMemoryLayoutEnum
{
VECTORIZED_LAYOUT = 0,
LINEAR_LAYOUT = 1,
};
// KV cache lookup table layout selector.
// - VLLM_BLOCK_TABLE_2D: block_table[batch, max_blocks_per_seq]
// - SGLANG_PAGE_TABLE_1D: kv_page_indices[kv_indptr[b] ... kv_indptr[b+1])
enum class BlockAttentionKVCacheLookupTableEnum
{
VLLM_BLOCK_TABLE_2D = 0,
SGLANG_PAGE_TABLE_1D = 1,
};
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
@@ -56,12 +57,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
static constexpr auto kKVMemoryLayout = FmhaPipeline::Problem::kKVMemoryLayout;
static constexpr auto kKVLookupTable = FmhaPipeline::Problem::kKVLookupTable;
static constexpr index_t kPageBlockSize = FmhaPipeline::kPageBlockSize;
static constexpr index_t kVectorSize = FmhaPipeline::kVectorSize;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
// arg
struct FmhaFwdEmptyKargs
@@ -71,6 +75,26 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct SglangPageTableKargs
{
const int32_t* kv_indptr;
const int32_t* kv_page_indices;
const int32_t* kv_last_page_lens;
};
struct VllmPageTableKargs
{
const int32_t* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
const int32_t* seqlen_k_ptr;
};
using PageBlockTableKargs =
std::conditional_t<kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
SglangPageTableKargs,
VllmPageTableKargs>;
struct FmhaFwdCommonKargs
{
const void* q_ptr;
@@ -89,14 +113,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t nhead_ratio_qk;
int32_t num_total_pages;
const int32_t* kv_indptr;
const int32_t* kv_page_indices;
#if 0 // we assume page_block_size=1 for now
const int32_t* kv_last_page_lens;
ck_tile::index_t page_block_size;
#else
static constexpr ck_tile::index_t page_block_size = 1;
#endif
PageBlockTableKargs page_table;
float scale_s;
@@ -295,12 +313,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
int32_t num_total_pages,
const void* kv_indptr,
const void* kv_page_indices,
#if 0 // we assume page_block_size=1 for now
const void* kv_last_page_lens,
ck_tile::index_t page_block_size,
#endif
const PageBlockTableKargs& page_table,
float scale_s,
[[maybe_unused]] float scale_p,
[[maybe_unused]] float scale_o,
@@ -345,12 +359,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
num_head_q,
nhead_ratio_qk,
num_total_pages,
reinterpret_cast<const int32_t*>(kv_indptr),
reinterpret_cast<const int32_t*>(kv_page_indices),
#if 0 // we assume page_block_size=1 for now
reinterpret_cast<const int32_t*>(kv_last_page_lens),
page_block_size,
#endif
page_table,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
@@ -453,12 +463,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
int32_t num_total_pages,
const void* kv_indptr,
const void* kv_page_indices,
#if 0 // we assume page_block_size=1 for now
const void* kv_last_page_lens,
ck_tile::index_t page_block_size,
#endif
const PageBlockTableKargs& page_table,
float scale_s,
[[maybe_unused]] float scale_p,
[[maybe_unused]] float scale_o,
@@ -498,12 +504,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
num_head_q,
nhead_ratio_qk,
num_total_pages,
reinterpret_cast<const int32_t*>(kv_indptr),
reinterpret_cast<const int32_t*>(kv_page_indices),
#if 0 // we assume page_block_size=1 for now
reinterpret_cast<const int32_t*>(kv_last_page_lens),
page_block_size,
#endif
page_table,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
@@ -700,10 +702,46 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
#if 0 // we assume page_block_size=1 for now
const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
#endif
const index_t seqlen_k = [&]() {
if constexpr(kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1];
const int32_t num_page_blocks = page_end - page_start;
const int32_t last_page_len = [&]() {
if constexpr(kPageBlockSize == 1)
return static_cast<int32_t>(kPageBlockSize);
else
return kargs.page_table.kv_last_page_lens[i_batch];
}();
return num_page_blocks > 0
? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
last_page_len)
: 0;
}
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
{
if(kargs.page_table.seqlen_k_ptr != nullptr)
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
else
return kargs.seqlen_k;
}
}();
const int32_t* page_idx = [&]() {
if constexpr(kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
}
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
{
return kargs.page_table.block_table_ptr +
static_cast<long_index_t>(i_batch) *
kargs.page_table.batch_stride_block_table;
}
}();
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
@@ -711,8 +749,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
batch_offset_q = query_start * kargs.stride_q;
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias;
@@ -737,18 +773,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
return;
}
#if 0 // we assume page_block_size=1 for now
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
#else
kargs.seqlen_k = num_page_blocks;
#endif
kargs.seqlen_k = seqlen_k;
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
@@ -764,11 +794,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
#if 0 // we assume page_block_size=1 for now
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
#else
kargs.seqlen_k = num_page_blocks;
#endif
kargs.seqlen_k = seqlen_k;
}
// for simplicity, batch stride we just modify the pointer
@@ -809,60 +835,137 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
}
}();
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
// Vectorized K Layout: [NumPages, D/kVectorSize, S, kVectorSize]
// Logical View for Pipeline: (TotalSeqK, D)
// Define the naive physical view with 4D shape: (NumPages, HeadDim/kVectorSize,
// PageBlockSize, kVectorSize)
// Strides: (BatchStride, PageBlockSize*kVectorSize, kVectorSize, 1)
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.num_total_pages,
kargs.hdim_q / kVectorSize,
kargs.page_block_size,
kVectorSize),
make_tuple(
kargs.batch_stride_k, kargs.page_block_size * kVectorSize, kVectorSize, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
const auto v_dram_transposed = transform_tensor_view(
v_dram_naive,
make_tuple(
make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
make_tuple(sequence<1>{}, sequence<0>{}),
// Merge to (TotalSeqK, D) in a single transform:
// physical (Page, D/vec, S, vec) -> logical (TotalSeqK, D)
auto k_dram_2d = transform_tensor_view(
k_dram_naive,
make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages,
kargs.page_block_size)), // TotalSeqK
make_merge_transform(
make_tuple(static_cast<int32_t>(kargs.hdim_q / kVectorSize),
static_cast<int32_t>(kVectorSize)))), // D
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
v_dram_transposed,
k_dram_2d,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
else
{
// Linear K Layout: [NumPages, PageSize, NumHeads, HeadDim]
// Logical View for Pipeline: (TotalSeqK, D)
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_q),
make_tuple(kargs.batch_stride_k, kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
// Merge to (TotalSeqK, D) in a single transform:
// physical (Page, S, D) -> logical (TotalSeqK, D)
auto k_dram_2d = transform_tensor_view(
k_dram_naive,
make_tuple(make_merge_transform(
make_tuple(kargs.num_total_pages, kargs.page_block_size)),
make_pass_through_transform(kargs.hdim_q)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
k_dram_2d,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
}();
const auto v_dram = [&]() {
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// Vectorized V Layout: [NumPages, S/kVectorSize, D, kVectorSize]
// Logical View for Pipeline: (D, TotalSeqK) - Transposed for GEMM
// Define the naive physical view with 4D shape: (NumPages,
// PageBlockSize/kVectorSize, HeadDim, kVectorSize)
// Strides: (BatchStride, HeadDim*kVectorSize, kVectorSize, 1)
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.num_total_pages,
kargs.page_block_size / kVectorSize,
kargs.hdim_v,
kVectorSize),
make_tuple(kargs.batch_stride_v, kargs.hdim_v * kVectorSize, kVectorSize, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
// Merge to (D, TotalSeqK) in a single transform:
// physical (Page, S/vec, D, vec) -> logical (D, TotalSeqK)
auto v_dram_final = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v), // D
make_merge_transform(make_tuple(kargs.num_total_pages,
kargs.page_block_size / kVectorSize,
kVectorSize))), // TotalSeqK
make_tuple(sequence<2>{}, sequence<0, 1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
v_dram_final,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
else
{
// Linear V Layout: [NumPages, PageSize, NumHeads, HeadDim]
// Logical View for Pipeline: (D, TotalSeqK)
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size),
make_tuple(kargs.stride_v, 1),
make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_v),
make_tuple(kargs.batch_stride_v, kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
return pad_tensor_view(
// Merge to (D, TotalSeqK) in a single transform:
// physical (Page, S, D) -> logical (D, TotalSeqK)
auto v_dram_final = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_merge_transform(
make_tuple(kargs.num_total_pages, kargs.page_block_size))),
make_tuple(sequence<2>{}, sequence<0, 1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
v_dram_final,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV_, kPadSeqLenK>{});
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
}();
auto q_dram_window = make_tile_window(
q_dram,
[&]() {
@@ -1070,6 +1173,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
const index_t stride_k_for_pipeline =
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT
? kVectorSize
: kargs.stride_k;
const index_t stride_v_for_pipeline =
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT
? kargs.hdim_v
: kargs.stride_v;
auto o_acc_tile = [&] {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
@@ -1108,9 +1220,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
variant_params,
block_indices,
smem_ptr,
kargs.kv_page_indices,
kargs.stride_k,
kargs.stride_v,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
}
else
@@ -1128,9 +1242,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
variant_params,
block_indices,
smem_ptr,
kargs.kv_page_indices,
kargs.stride_k,
kargs.stride_v,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
}
}();

Some files were not shown because too many files have changed in this diff Show More