mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
Merge branch 'develop' into users/yiding12/fmha-bwd-workspace
This commit is contained in:
@@ -10,7 +10,7 @@ if(NOT INST_TARGETS)
|
||||
endif()
|
||||
|
||||
# validate user-specified fmha_fwd API list
|
||||
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill")
|
||||
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill;batch_prefill")
|
||||
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(BUILD_TESTING)
|
||||
@@ -48,7 +48,6 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS
|
||||
--targets ${FMHA_TARGETS_ARG}
|
||||
--api ${FMHA_FWD_APIS}
|
||||
--optdim 32,64,80,128,256
|
||||
# --filter fmha_fwd...
|
||||
)
|
||||
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
@@ -174,12 +173,47 @@ else()
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
|
||||
endif()
|
||||
|
||||
# conditionally enable call to the batch_prefill API in fmha_fwd example and tests
# The macro is always defined: 1 when "batch_prefill" is listed in FMHA_FWD_ENABLE_APIS,
# 0 otherwise, so it can be tested with a plain `#if` on the C++ side.
|
||||
if("batch_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=1)
|
||||
else()
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=0)
|
||||
endif()
|
||||
|
||||
# conditionally specify the use of OCP_FP8
# When CK_USE_OCP_FP8 is truthy, -DCK_TILE_USE_OCP_FP8 is appended to both the private and the
# interface fwd option lists; when it is not, neither list is modified.
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
# Classify the entries of INST_TARGETS: FMHA_HAS_RDNA_TARGET is switched ON when at least one
# target name starts with gfx11 or gfx12, FMHA_HAS_NON_RDNA_TARGET when at least one does not.
# Both flags start OFF and are never reset inside the loop, so a mixed list ends with both ON.
set(FMHA_HAS_RDNA_TARGET OFF)
|
||||
set(FMHA_HAS_NON_RDNA_TARGET OFF)
|
||||
foreach(inst_target ${INST_TARGETS})
|
||||
# the pattern is anchored at the start, so only the prefix of the target name is checked
if(inst_target MATCHES "^(gfx11|gfx12)")
|
||||
set(FMHA_HAS_RDNA_TARGET ON)
|
||||
else()
|
||||
set(FMHA_HAS_NON_RDNA_TARGET ON)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# RDNA-specific float->bf16 conversion default for the fwd kernels. Only runs when at least one
# RDNA target was requested. The generated fwd sources whose file name carries a gfx11/gfx12 tag
# get CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=5 as a per-source compile definition.
# NOTE(review): what conversion mode 5 selects is defined outside this file - confirm.
if(FMHA_HAS_RDNA_TARGET)
|
||||
# set() without a value unsets the variable, i.e. the collection starts empty
set(FMHA_FWD_RDNA_GEN_BLOBS)
|
||||
foreach(fwd_blob ${FMHA_FWD_GEN_BLOBS})
|
||||
# `[^/]*` forbids a path separator between the tag and the `.cpp` suffix, so the
# `_gfx11`/`_gfx12` tag has to appear in the file name itself, not in a parent directory
if(fwd_blob MATCHES "_gfx1[12][^/]*\\.cpp$")
|
||||
list(APPEND FMHA_FWD_RDNA_GEN_BLOBS ${fwd_blob})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# skip set_property() entirely when no generated source matched
if(FMHA_FWD_RDNA_GEN_BLOBS)
|
||||
set_property(SOURCE ${FMHA_FWD_RDNA_GEN_BLOBS}
|
||||
APPEND PROPERTY COMPILE_DEFINITIONS CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=5)
|
||||
endif()
|
||||
|
||||
# with RDNA targets only, also publish the same default through the interface options
if(NOT FMHA_HAS_NON_RDNA_TARGET)
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=5)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# use RTN_ASM on float to bfloat16 conversion by default, align with FA upstream
# The same definition (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) is appended unconditionally to both
# the private and the interface FMHA_BWD option lists.
|
||||
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
|
||||
list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
|
||||
|
||||
@@ -15,7 +15,7 @@ Running the build recipe will produce the executable `tile_example_fmha_fwd`.
|
||||
|
||||
The executables reside in `bin` subdirectory of the build directory.
|
||||
|
||||
This example provides recipes for `tile_example_fmha_fwd`, `tile_example_fmha_bwd`, `tile_example_fmha_fwd_v3`.
|
||||
This example provides recipes for `tile_example_fmha_fwd`, `tile_example_fmha_bwd`.
|
||||
|
||||
> [!NOTE]
|
||||
> `cmake-ck-dev.sh` is a CMake wrapper.
|
||||
@@ -62,14 +62,17 @@ args:
|
||||
-d_v head dim for v, -1 means equal to d (default:-1)
|
||||
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
|
||||
-qscale n or 0, no scaling (default:n)
|
||||
1: per-tensor quantization.
|
||||
pt or 1, per-tensor scale
|
||||
bs or 2, block scale
|
||||
kvbs or 3, Q per-tensor, K/V per-page block scale, only in batch_prefill
|
||||
mx or 4, microscaling (exclusively for mxfp8/mxfp4)
|
||||
-iperm permute input (default:1)
|
||||
if true, will be b*h*s*d, else b*s*h*d
|
||||
-operm permute output (default:1)
|
||||
-bias n or 0, no bias (default:n)
|
||||
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
|
||||
a(libi) or 2, alibi with 1*h. a:1, b*h
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-prec data type. fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4 (default:fp16)
|
||||
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
|
||||
't', top-left causal mask, 'b', bottom-r causal mask
|
||||
't:l,r', top-left sliding window attn(swa) with FA style left right size
|
||||
@@ -161,7 +164,23 @@ We support sequence padding and variable-length processing in both batch and gro
|
||||
|
||||
Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.
|
||||
|
||||
## FP8 experimental support
|
||||
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+.
|
||||
## FP8 support
|
||||
FP8 FMHA kernels are supported on gfx942/gfx950 machines with ROCm 6.0+. Three fp8-based precision modes are available via `-prec`:
|
||||
|
||||
Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later.
|
||||
| `-prec` value | Q/K/V input type | Output type | Description |
|
||||
|---|---|---|---|
|
||||
| `fp8` | fp8 | fp8 | Fully fp8: both inputs and output are in fp8 |
|
||||
| `fp8bf16` | fp8 | bf16 | Mixed precision: fp8 inputs, bf16 output — useful when the consumer expects a wider-range output format |
|
||||
| `fp8fp32` | fp8 | fp32 | Mixed precision: fp8 inputs, fp32 output — highest-precision output, suitable for debugging or further fp32 processing |
|
||||
|
||||
The following quantization scale modes are available via `-qscale`:
|
||||
|
||||
| `-qscale` value | Description |
|
||||
|---|---|
|
||||
| `n` or `0` | No quantization scale (default) |
|
||||
| `pt` or `1` | Per-tensor quantization scale — a single scale factor is applied to the entire tensor |
|
||||
| `bs` or `2` | Per-block quantization scale — a scale factor is applied per block of elements |
|
||||
| `kvbs` or `3` | Q per-tensor + K/V per-page block scale (batch_prefill only) |
|
||||
| `mx` or `4` | Microscaling (MX format), exclusively for `mxfp8` and `mxfp4` data types |
|
||||
|
||||
Currently only `-vlayout=r` (`seqlen*hdim` for V matrix) is supported for fp8 data types.
|
||||
|
||||
@@ -139,6 +139,7 @@ LAYOUT_MAP = {"row": "true", "col": "false"}
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"qr": "ck_tile::BlockFmhaPipelineQRKSVS",
|
||||
"qr_hpad": "ck_tile::BlockFmhaPipelineQRKSVSHpad",
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
|
||||
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
|
||||
@@ -147,6 +148,7 @@ PIPELINE_MAP = {
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_hpad": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_HPAD",
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
"qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
|
||||
@@ -84,6 +84,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
|
||||
{F_qscale},
|
||||
{F_occupancy},
|
||||
false,
|
||||
{F_sink},
|
||||
{F_page_size},
|
||||
{F_kv_memory_layout},
|
||||
{F_kv_lookup_table}>;
|
||||
@@ -124,7 +125,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -201,9 +202,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -247,6 +248,7 @@ class FmhaFwdApiTrait:
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
sink: str # t/f
|
||||
constraint: CppConstraint
|
||||
kv_memory_layout: str
|
||||
kv_lookup_table: str
|
||||
@@ -343,6 +345,7 @@ class FmhaFwdPipeline:
|
||||
F_dropout: str #
|
||||
F_qscale: str # no/pertensor
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_sink: str # t/f (StreamLLM sink tokens)
|
||||
F_kv_memory_layout: str #
|
||||
F_kv_lookup_table: str #
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
@@ -406,6 +409,11 @@ class FmhaFwdPipeline:
|
||||
else:
|
||||
n += "_nqscale"
|
||||
|
||||
if self.F_sink == "t":
|
||||
n += "_sink"
|
||||
else:
|
||||
n += "_nsink"
|
||||
|
||||
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
|
||||
return n
|
||||
|
||||
@@ -472,6 +480,7 @@ class FmhaFwdApiPool:
|
||||
trait.kv_lookup_table
|
||||
],
|
||||
F_page_size=trait.page_size,
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
@@ -578,6 +587,7 @@ class FmhaFwdKernel:
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_page_size=self.F_page_size,
|
||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -617,6 +627,7 @@ class FmhaFwdKernel:
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
sink=self.F_pipeline.F_sink,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
|
||||
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
|
||||
@@ -655,6 +666,7 @@ class KernelComponentFactory:
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
sink,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
@@ -663,12 +675,13 @@ class KernelComponentFactory:
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
# no need lse/dropout kernels
|
||||
# no need lse/dropout/sink kernels
|
||||
for (
|
||||
logits,
|
||||
qscale,
|
||||
@@ -684,7 +697,7 @@ class KernelComponentFactory:
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -701,20 +714,34 @@ class CustomFactory(KernelComponentFactory):
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl,
|
||||
targets: Optional[List[str]] = None
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing
|
||||
# (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with
|
||||
# non-gfx9 architectures (gfx11/gfx12/gfx10 are wave32 and use different
|
||||
# buffer instruction formats). Skip all batch_prefill kernels for non-gfx9 targets.
|
||||
has_non_gfx9 = targets is not None and any(
|
||||
not t.startswith("gfx9") for t in targets
|
||||
)
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
if has_non_gfx9:
|
||||
return api_pool, gen
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = CustomFactory.get_hdim_tile_size_dict(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
# batch_prefill pipeline requires group mode (static_assert in pipeline problem)
|
||||
if mode != "group":
|
||||
continue
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)
|
||||
):
|
||||
@@ -829,7 +856,7 @@ def write_blobs(
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
@@ -844,7 +871,7 @@ def list_blobs(
|
||||
mask_impl,
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
|
||||
@@ -60,6 +60,22 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
# File header emitted at the top of generated kernel sources for the "qr_hpad" pipeline.
# Besides the license banner and the two includes, it carries a preprocessor guard: during HIP
# device compilation for one of the listed gfx11/gfx12 targets it defines
# CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK to 1, unless that macro is
# already defined. The define is placed ahead of the #include lines.
FMHA_FWD_KERNEL_HEADER_QR_HPAD = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && \
|
||||
(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
|
||||
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
|
||||
defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) || \
|
||||
defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__))
|
||||
#if !defined(CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK)
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
#endif
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY_TEMPLATE = """
|
||||
#include <iostream>
|
||||
|
||||
@@ -300,7 +316,7 @@ class FmhaFwdApiTrait:
|
||||
return "true" # always support
|
||||
else:
|
||||
return "true"
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
|
||||
if self.spad == "t":
|
||||
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
@@ -323,7 +339,7 @@ class FmhaFwdApiTrait:
|
||||
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
|
||||
else:
|
||||
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
|
||||
if self.skpad == "t":
|
||||
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
@@ -344,6 +360,11 @@ class FmhaFwdApiTrait:
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag == "qr_hpad":
|
||||
if self.dpad == "t":
|
||||
return "a.hdim_q % 8 == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == "t":
|
||||
@@ -361,6 +382,11 @@ class FmhaFwdApiTrait:
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag == "qr_hpad":
|
||||
if self.dvpad == "t":
|
||||
return "a.hdim_v % 8 == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == "t":
|
||||
@@ -634,6 +660,7 @@ class FmhaFwdKernel:
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
|
||||
_KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER
|
||||
_KERNEL_HEADER_QR_HPAD: ClassVar[str] = FMHA_FWD_KERNEL_HEADER_QR_HPAD
|
||||
_KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE
|
||||
|
||||
@classmethod
|
||||
@@ -643,6 +670,12 @@ class FmhaFwdKernel:
|
||||
else:
|
||||
return "ck_tile::FmhaFwdKernel"
|
||||
|
||||
# Pick the file header for a generated kernel source: the "qr_hpad" pipeline tag gets
# cls._KERNEL_HEADER_QR_HPAD, every other tag falls through to the default cls._KERNEL_HEADER.
@classmethod
|
||||
def _get_kernel_header(cls, pipeline_tag):
|
||||
if pipeline_tag == "qr_hpad":
|
||||
return cls._KERNEL_HEADER_QR_HPAD
|
||||
return cls._KERNEL_HEADER
|
||||
|
||||
@classmethod
|
||||
def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
|
||||
if pipeline_tag == "qr_async_trload_v3":
|
||||
@@ -651,7 +684,9 @@ class FmhaFwdKernel:
|
||||
return "fmha_fwd_create_kargs_and_grids"
|
||||
|
||||
def render(self) -> str:
|
||||
return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format(
|
||||
return type(self)._get_kernel_header(self.F_pipeline.tag) + type(
|
||||
self
|
||||
)._KERNEL_BODY_TEMPLATE.format(
|
||||
F_kname=self.name,
|
||||
F_arch=self.F_arch,
|
||||
F_hdim=self.F_hdim,
|
||||
@@ -1144,6 +1179,32 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
def supported_dtypes(cls) -> Tuple[str]:
|
||||
return cls._DT_FP16_BF16
|
||||
|
||||
# Extend the rule list returned by the parent class with one extra compatibility rule.
# The rule returns False (reject) only when all of the following hold: the dtype is in
# cls._DT_FP16_BF16, (hdim, hdim_v) == (128, 128), the tile has F_bm0 == 64 and F_bn0 == 64, and
# the pipeline tag is "qr_hpad". Every other combination returns True.
@classmethod
|
||||
def get_rules(cls) -> List[CompatibilityRule]:
|
||||
rules = super().get_rules()
|
||||
|
||||
def check_d128_tile_pipeline(
|
||||
problem_ctx: ProblemContext, kernel_ctx: KernelContext
|
||||
) -> bool:
|
||||
# rule only concerns the dtypes in _DT_FP16_BF16; anything else passes
if problem_ctx.dtype not in cls._DT_FP16_BF16:
|
||||
return True
|
||||
|
||||
# rule only concerns the (128, 128) head-dim pair; anything else passes
if (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128):
|
||||
return True
|
||||
|
||||
# For (128, 128) head dims, partial-fragment support in qr_hpad removes the need
|
||||
# for the previous qr_hpad-specific handling that was added to avoid register spill.
|
||||
# qr_hpad now reuses the regular 128x64 tile choice.
|
||||
# The 64x64 tile remains disabled for qr_hpad because it is consistently slower
|
||||
# in our measurements.
|
||||
if kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 64:
|
||||
return kernel_ctx.pipeline.tag != "qr_hpad"
|
||||
|
||||
return True
|
||||
|
||||
# the list obtained from super() is extended in place and then returned
rules.append(check_d128_tile_pipeline)
|
||||
return rules
|
||||
|
||||
@classmethod
|
||||
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
|
||||
if dtype in cls._DT_FP16_BF16:
|
||||
@@ -1152,7 +1213,8 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
|
||||
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
|
||||
# max_seqlen_q cutoff retuned after the bf16 standard_cnan change.
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 2048")),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
|
||||
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)]
|
||||
@@ -1179,7 +1241,9 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
# Keep only ttff/tttt for gfx11: ffff path is often similar or worse
|
||||
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
if receipt == 1:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
return pipelines
|
||||
|
||||
|
||||
@@ -1213,7 +1277,8 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q <= 8192")),
|
||||
# max_seqlen_q cutoff retuned after the bf16 standard_cnan change.
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q <= 4096")),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
|
||||
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
@@ -1251,7 +1316,9 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
|
||||
):
|
||||
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
if receipt == 1:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
|
||||
@@ -1452,6 +1452,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadDv_,
|
||||
bool kUseTrLoad_,
|
||||
bool kSkipMinSeqlenQ_ = false,
|
||||
bool kHasSink_ = false,
|
||||
ck_tile::index_t kPageBlockSize_ = 1,
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
@@ -1480,7 +1481,7 @@ struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
|
||||
kPadDv_,
|
||||
kUseTrLoad_,
|
||||
kSkipMinSeqlenQ_,
|
||||
false>
|
||||
kHasSink_>
|
||||
{
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
|
||||
@@ -387,7 +387,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
|
||||
#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \
|
||||
CK_TILE_FMHA_FWD_PAGEDKV_API))
|
||||
CK_TILE_FMHA_FWD_PAGEDKV_API || CK_TILE_FMHA_FWD_BATCH_PREFILL_API))
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
|
||||
@@ -395,7 +395,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
page_block_size = 0;
|
||||
}
|
||||
#endif
|
||||
if(!(page_block_size % 128 == 0))
|
||||
// batch_prefill supports flexible page sizes (not just multiples of 128)
|
||||
const bool need_128_aligned_page =
|
||||
(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API ||
|
||||
CK_TILE_FMHA_FWD_PAGEDKV_API);
|
||||
if(need_128_aligned_page && 0 < page_block_size && !(page_block_size % 128 == 0))
|
||||
{
|
||||
std::cerr << "only paged-kvcache block size divisible by 128 are currently supported"
|
||||
<< std::endl;
|
||||
@@ -972,9 +976,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0);
|
||||
// Buffers for key/value per-sequence logical (unpadded) lengths (used in batch mode with
|
||||
// kvcache or group mode with padding enabled)
|
||||
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
|
||||
? seqlen_ks.size() * sizeof(int32_t)
|
||||
: 0);
|
||||
// batch_prefill (group+kvcache) also needs per-batch seqlen_k for VLLM_BLOCK_TABLE_2D
|
||||
const bool need_seqlen_k_buf = (mode == mode_enum::batch && use_kvcache) ||
|
||||
has_group_k_padding || (mode == mode_enum::group && use_kvcache);
|
||||
ck_tile::DeviceMem seqlen_k_buf(need_seqlen_k_buf ? seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
|
||||
: cuq_cum.size() * sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem cu_seqlen_kv_buf(
|
||||
@@ -1013,9 +1018,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data());
|
||||
cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data());
|
||||
seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr);
|
||||
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
|
||||
? seqlen_ks.data()
|
||||
: nullptr);
|
||||
seqlen_k_buf.ToDevice(need_seqlen_k_buf ? seqlen_ks.data() : nullptr);
|
||||
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
|
||||
rotary_cos_buf.ToDevice(rotary_cos_host.data());
|
||||
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
||||
@@ -1146,6 +1149,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
{
|
||||
traits.use_pagedkv = (0 < page_block_size);
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_batch_prefill_traits,
|
||||
std::decay_t<decltype(traits)>>)
|
||||
{
|
||||
traits.has_dropout = (p_drop > 0.0f);
|
||||
traits.qscale_type = qscale.type;
|
||||
traits.kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT;
|
||||
traits.kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D;
|
||||
traits.page_size = page_block_size;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1498,6 +1512,67 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_batch_prefill_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
// Fields already set by the outer else block above:
|
||||
// bias_ptr, lse_ptr, o_ptr, seqlen_k, max_seqlen_q, scale_s,
|
||||
// logits_soft_cap, stride_bias/o, nhead/batch stride for bias/lse/o,
|
||||
// window_size_left/right, sink_size, mask_type.
|
||||
|
||||
// scale_p/scale_o: batch_prefill-specific fields absent from fmha_fwd_args.
|
||||
args.scale_p = 1.f;
|
||||
args.scale_o = 1.f;
|
||||
|
||||
// Dropout fields: the outer fmha_fwd_args branch sets these; set them here
|
||||
// for batch_prefill since it takes a separate inner branch.
|
||||
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
|
||||
args.stride_randval = stride_randval;
|
||||
args.nhead_stride_randval = nhead_stride_randval;
|
||||
args.batch_stride_randval = batch_stride_randval;
|
||||
args.p_drop = p_drop;
|
||||
args.s_randval = s_randval;
|
||||
if(drop_prefs)
|
||||
args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(),
|
||||
drop_offset_buf.GetDeviceBuffer());
|
||||
else
|
||||
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
|
||||
|
||||
// Paged KV: LINEAR_LAYOUT + VLLM_BLOCK_TABLE_2D
|
||||
// block_table_buf: [batch, max_blocks_per_seq] of physical page ids
|
||||
// seqlen_k_buf: [batch] of per-batch seqlen_k values
|
||||
args.num_total_pages = max_num_page_blocks;
|
||||
args.page_block_size = page_block_size;
|
||||
args.kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT;
|
||||
args.kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D;
|
||||
args.kv_indptr = nullptr;
|
||||
args.kv_page_indices = block_table_buf.GetDeviceBuffer();
|
||||
args.kv_last_page_lens = nullptr;
|
||||
args.seqlen_k_ptr = seqlen_k_buf.GetDeviceBuffer();
|
||||
args.batch_stride_block_table = batch_stride_block_table;
|
||||
|
||||
// group mode required: seqstart_q is prefix-sum of per-batch seqlen_q
|
||||
args.seqstart_q_ptr = seqstart_q_buf.GetDeviceBuffer();
|
||||
|
||||
// batch_prefill LINEAR_LAYOUT strides for runner's K layout
|
||||
// [max_num_page_blocks, nhead_k, page_block_size, hdim]:
|
||||
// stride_k = hdim_q (token stride within one head's page slice)
|
||||
// nhead_stride_k = page_block_size * hdim_q (head stride)
|
||||
// batch_stride_k = nhead_k * page_block_size * hdim_q (page stride, already set)
|
||||
args.stride_k = hdim_q;
|
||||
args.nhead_stride_k = page_block_size * hdim_q;
|
||||
// V is row-major, same layout convention
|
||||
args.stride_v = hdim_v;
|
||||
args.nhead_stride_v = page_block_size * hdim_v;
|
||||
|
||||
// descale: not used for fp16/bf16
|
||||
args.q_descale_ptr = nullptr;
|
||||
args.k_descale_ptr = nullptr;
|
||||
args.v_descale_ptr = nullptr;
|
||||
args.nblock_stride_kv_block_descale = 0;
|
||||
args.nhead_stride_kv_block_descale = 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1524,6 +1599,21 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
|
||||
auto run_fwd = [&](const ck_tile::stream_config& sc) {
|
||||
#if CK_TILE_FMHA_FWD_BATCH_PREFILL_API
|
||||
// batch_prefill: group mode + paged KV, tested against the same CPU reference
|
||||
if(1 == num_splits && use_kvcache && mode == mode_enum::group)
|
||||
{
|
||||
fmha_batch_prefill_traits bp_traits;
|
||||
init_traits(bp_traits);
|
||||
|
||||
fmha_batch_prefill_args bp_args;
|
||||
init_args(bp_args);
|
||||
|
||||
const float ave_time = fmha_batch_prefill(bp_traits, bp_args, sc);
|
||||
if(ave_time >= 0.0f)
|
||||
return ave_time;
|
||||
}
|
||||
#endif // CK_TILE_FMHA_FWD_BATCH_PREFILL_API
|
||||
#if CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
if(1 == num_splits && use_kvcache)
|
||||
{
|
||||
@@ -1844,7 +1934,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
|
||||
}
|
||||
#endif
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API || \
|
||||
CK_TILE_FMHA_FWD_BATCH_PREFILL_API
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
// clang-format off
|
||||
@@ -1895,7 +1986,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
});
|
||||
}
|
||||
#endif
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API || \
|
||||
CK_TILE_FMHA_FWD_BATCH_PREFILL_API
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
if(is_v_rowmajor)
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
#!/bin/sh
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
# TODO: run this script from CK root or build directory
|
||||
EXE="$(find . -name tile_example_fmha_fwd_v3 -type f | head -n 1)"
|
||||
VALID=0
|
||||
|
||||
for causal in 0 1 ; do
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for hdim in 128 ; do
|
||||
for perm in 0 ; do
|
||||
|
||||
$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
|
||||
$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# Padding benchmark comparisons for v3 (batch mode only)
|
||||
# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ====
|
||||
prec="fp16"
|
||||
base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID"
|
||||
|
||||
# baseline (no pad)
|
||||
$EXE $base_v3_args
|
||||
|
||||
# low pad (≈90–95% effective)
|
||||
$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
|
||||
|
||||
# medium pad (≈60–75% effective)
|
||||
$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
|
||||
|
||||
# high pad (≈30–40% effective)
|
||||
$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
|
||||
@@ -58,27 +58,45 @@ struct WeightPreshuffleInvoker
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
GemmConfig::TiledMMAPermuteN>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
GemmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups>>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
|
||||
@@ -84,7 +84,6 @@ struct UniversalInvoker
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
|
||||
@@ -228,7 +227,6 @@ struct UniversalInvoker
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer>>;
|
||||
|
||||
|
||||
@@ -5,14 +5,6 @@
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
#if 0
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , true, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 2, 128, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 2, 128, 4, true, true, false>>(const S&, A);
|
||||
|
||||
@@ -5,14 +5,6 @@
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
#if 0
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , true, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 2, 128, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 2, 128, 4, true, true, false>>(const S&, A);
|
||||
|
||||
@@ -5,14 +5,6 @@
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
#if 0
|
||||
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true, false>>(const S&, A);
|
||||
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true, false>>(const S&, A);
|
||||
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true, false>>(const S&, A);
|
||||
|
||||
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
|
||||
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
|
||||
|
||||
@@ -5,14 +5,6 @@
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
#if 0
|
||||
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true ,false>>(const S&, A);
|
||||
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true ,false>>(const S&, A);
|
||||
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true ,false>>(const S&, A);
|
||||
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true ,false>>(const S&, A);
|
||||
|
||||
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true ,false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
|
||||
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
|
||||
|
||||
@@ -5,14 +5,6 @@
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
#if 0
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true, false>>(const S&, A);
|
||||
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
|
||||
|
||||
@@ -5,14 +5,6 @@
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
#if 0
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true ,false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true ,false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true ,false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true ,false>>(const S&, A);
|
||||
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true ,false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
|
||||
|
||||
@@ -188,27 +188,45 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups>>>;
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
@@ -230,6 +248,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "epilogue: " << GemmEpilogue::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
|
||||
@@ -139,28 +139,48 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
BlockedXDLN_PerWarp>>>;
|
||||
|
||||
using CodegenFlatmmPipeline = std::conditional_t<
|
||||
MXFP4_Pipeline,
|
||||
|
||||
@@ -108,28 +108,48 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
false, // FixedVectorSize
|
||||
1>>, // VectorSizeC
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
BlockedXDLN_PerWarp>>>;
|
||||
|
||||
using Kernel =
|
||||
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
@@ -163,28 +163,48 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
? 2
|
||||
: 1; // determined by scale shuffle pattern
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
BlockedXDLN_PerWarp>>>;
|
||||
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
@@ -84,7 +84,26 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
using GemmEpilogue =
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<ck_tile::PermuteNEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
MXPipelineProblem::TransposeC,
|
||||
false, // FixedVectorSize
|
||||
1>>, // VectorSizeC
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
@@ -104,8 +123,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
BlockedXDLN_PerWarp>>>;
|
||||
|
||||
using Kernel = ck_tile::MXFlatmmKernel<TilePartitioner, MXFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
|
||||
@@ -207,27 +207,44 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
printf(
|
||||
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
|
||||
}
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
|
||||
typename PipelineProblem::ComputeDataType,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledPermuteN>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
TiledPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<typename PipelineProblem::ComputeDataType,
|
||||
typename PipelineProblem::ComputeDataType,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
|
||||
typename PipelineProblem::ComputeDataType,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c>>>;
|
||||
using Kernel =
|
||||
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
|
||||
|
||||
|
||||
@@ -2,3 +2,4 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(tile_example_tile_distr_enc_reg_map example_tile_distr_enc_reg_map.cpp)
|
||||
add_executable(tile_example_tile_distr_enc_calc example_tile_distr_enc_calc.cpp)
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdio>
|
||||
#include <type_traits>
|
||||
#include <tuple>
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
using namespace ck_tile::core::arch;
|
||||
using namespace mma;
|
||||
using F16 = fp16_t;
|
||||
using F32 = fp32_t;
|
||||
using Target908 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX908>());
|
||||
using Target950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
|
||||
using Target11 = decltype(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1100>());
|
||||
using Target12 = decltype(make_amdgcn_gfx12_target<amdgcn_target_id::GFX1201>());
|
||||
|
||||
// Prints the register maps of the A, B and C warp distribution encodings computed by
// TileDistrEncCalc for the given MmaOp, then checks them against the encodings obtained
// with the second TileDistrEncCalc template argument set to true (referred to below as
// CTranspose / TransposeC).
//
// Template parameter:
//   MmaOp - the MMA operation type handed to TileDistrEncCalc.
// Returns:
//   0 when every (lane, vector item) pair of the C encoding maps to swapped matrix
//   coordinates in the transposed C encoding, 1 if at least one pair is inconsistent.
//   The A/B swap is enforced at compile time through static_assert.
template <typename MmaOp>
|
||||
int check_tile_distr_enc()
|
||||
{
|
||||
using AEnc = typename TileDistrEncCalc<MmaOp>::AWarpDstrEncoding;
|
||||
using BEnc = typename TileDistrEncCalc<MmaOp>::BWarpDstrEncoding;
|
||||
using CEnc = typename TileDistrEncCalc<MmaOp>::CWarpDstrEncoding;
|
||||
|
||||
// Dump the register mapping of each non-transposed encoding.
TileDistrEncRegMap<AEnc>::print();
|
||||
TileDistrEncRegMap<BEnc>::print();
|
||||
TileDistrEncRegMap<CEnc>::print();
|
||||
|
||||
// The only thing we check here is that CTranspose works as expected.
|
||||
using AEncTransp = typename TileDistrEncCalc<MmaOp, true>::AWarpDstrEncoding;
|
||||
using BEncTransp = typename TileDistrEncCalc<MmaOp, true>::BWarpDstrEncoding;
|
||||
using CEncTransp = typename TileDistrEncCalc<MmaOp, true>::CWarpDstrEncoding;
|
||||
|
||||
// When using TransposeC, the A and B matrix layouts should be swapped.
|
||||
static_assert(std::is_same<AEncTransp, BEnc>());
|
||||
static_assert(std::is_same<BEncTransp, AEnc>());
|
||||
|
||||
// Make sure the C matrix layout is transposed in the CTranspose case.
|
||||
int err = 0;
|
||||
// NOTE(review): the loop bounds come from CEnc only; this presumes CEncTransp has the
// same num_lanes and num_vector_items - confirm.
for(index_t lane = 0; lane < TileDistrEncRegMap<CEnc>::num_lanes; lane++)
|
||||
{
|
||||
for(index_t vec = 0; vec < TileDistrEncRegMap<CEnc>::num_vector_items; vec++)
|
||||
{
|
||||
auto coords = TileDistrEncRegMap<CEnc>::calc_matrix_indices_from_lane_vector(lane, vec);
|
||||
auto coords_transp =
|
||||
TileDistrEncRegMap<CEncTransp>::calc_matrix_indices_from_lane_vector(lane, vec);
|
||||
|
||||
// Transposed coordinates must be the plain coordinates with both entries swapped.
if(coords[0] != coords_transp[1] || coords[1] != coords_transp[0])
|
||||
{
|
||||
// Record the failure but keep scanning so that every mismatch gets reported.
err = 1;
|
||||
// "\033[31m" ... "\033[0m" wraps the message in ANSI red.
printf("\033[31mLane %2d vec %2d maps to C matrix coords %2d %2d and transposed C "
|
||||
"matrix coords %2d %2d, inconsistent!\033[0m\n",
|
||||
lane,
|
||||
|
||||
vec,
|
||||
coords[0],
|
||||
coords[1],
|
||||
coords_transp[0],
|
||||
coords_transp[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
// List of intrinsics to test.
|
||||
// clang-format off
|
||||
using Intrinsics = ck_tile::tuple<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Runs check_tile_distr_enc for every MMA operation listed in Intrinsics and returns the
// OR of the individual results as the process exit code (0 when all checks pass).
int main()
|
||||
{
|
||||
int err = 0;
|
||||
// Compile-time loop over the tuple indices: each iteration instantiates the check for
// one element type of Intrinsics.
static_for<0, Intrinsics::size(), 1>{}([&](auto i) {
|
||||
using MmaOp = std::tuple_element_t<i.value, Intrinsics>;
|
||||
// Accumulate with |= so a single failing intrinsic makes the whole run fail.
err |= check_tile_distr_enc<MmaOp>();
|
||||
});
|
||||
return err;
|
||||
}
|
||||
128
example/ck_tile/52_cshuffle_lds/CMakeLists.txt
Normal file
128
example/ck_tile/52_cshuffle_lds/CMakeLists.txt
Normal file
@@ -0,0 +1,128 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# CShuffleLds LDS store/load microbenchmark suite
|
||||
# Measures LDS bandwidth and bank conflicts for different MFMA configurations
|
||||
|
||||
# All generated benchmark sources live under the binary dir, never the source tree.
set(GENERATED_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
file(MAKE_DIRECTORY "${GENERATED_SOURCE_DIR}")

# Core function: generate and build one benchmark executable.
#
# benchmark_template.cpp.in is instantiated (with @ONLY substitution) into
# "${GENERATED_SOURCE_DIR}/${NAME}.cpp", compiled as HIP and built as the
# executable target ${NAME}.
#
# The template reads its @...@ placeholders from this function's parameters,
# so the parameter names below must stay in sync with the template:
#   NAME                      - target name and generated file stem
#   A_TYPE B_TYPE ACC_TYPE O_TYPE
#                             - C++ type spellings substituted into the template
#   M N                       - block tile dimensions
#   M_WAVE N_WAVE             - wave layout
#   M_XDL N_XDL K_XDL         - MFMA tile dimensions
#   CONFIG_NAME               - label printed by the benchmark
function(add_cshuffle_lds_benchmark NAME A_TYPE B_TYPE ACC_TYPE O_TYPE M N M_WAVE N_WAVE M_XDL N_XDL K_XDL CONFIG_NAME)
    set(generated_cpp "${GENERATED_SOURCE_DIR}/${NAME}.cpp")
    configure_file(
        "${CMAKE_CURRENT_SOURCE_DIR}/benchmark_template.cpp.in"
        "${generated_cpp}"
        @ONLY)
    # The generated file has a .cpp suffix but contains device code.
    set_source_files_properties(${generated_cpp} PROPERTIES LANGUAGE HIP)

    add_executable(${NAME} ${generated_cpp})
    set_property(TARGET ${NAME} PROPERTY HIP_ARCHITECTURES ${SUPPORTED_GPU_TARGETS})
    target_link_libraries(${NAME} PRIVATE hip::device)
    target_include_directories(${NAME}
        PRIVATE
            ${PROJECT_SOURCE_DIR}/include
            ${PROJECT_SOURCE_DIR}/test
            ${CMAKE_CURRENT_SOURCE_DIR})

    if(CK_USE_OCP_FP8)
        target_compile_options(${NAME} PRIVATE -DCK_TILE_USE_OCP_FP8)
    endif()
endfunction()
|
||||
|
||||
# Type-specific wrappers (derive name and config from parameters)
|
||||
# FP16 benchmark: half_t inputs and output, float accumulator.
# Target: bench_lds_fp16_<MxNxK xdl>_<wave layout>; config label: FP16_<same tag>.
function(add_fp16_benchmark M N M_WAVE N_WAVE M_XDL N_XDL K_XDL)
    # "<tile>_<wave layout>" tag shared by the target name and the config label.
    set(shape_tag "${M_XDL}x${N_XDL}x${K_XDL}_${M_WAVE}x${N_WAVE}")
    add_cshuffle_lds_benchmark(
        "bench_lds_fp16_${shape_tag}"
        "ck_tile::half_t"   # A_TYPE
        "ck_tile::half_t"   # B_TYPE
        "float"             # ACC_TYPE
        "ck_tile::half_t"   # O_TYPE
        ${M} ${N}
        ${M_WAVE} ${N_WAVE}
        ${M_XDL} ${N_XDL} ${K_XDL}
        "FP16_${shape_tag}")
endfunction()
|
||||
|
||||
# FP8 -> FP16 benchmark: fp8_t inputs, float accumulator, half_t output.
# Target: bench_lds_fp8_<tag>_fp16; config label: FP8_<tag>_fp16.
function(add_fp8_fp16_benchmark M N M_WAVE N_WAVE M_XDL N_XDL K_XDL)
    # "<tile>_<wave layout>" tag shared by the target name and the config label.
    set(shape_tag "${M_XDL}x${N_XDL}x${K_XDL}_${M_WAVE}x${N_WAVE}")
    add_cshuffle_lds_benchmark(
        "bench_lds_fp8_${shape_tag}_fp16"
        "ck_tile::fp8_t"    # A_TYPE
        "ck_tile::fp8_t"    # B_TYPE
        "float"             # ACC_TYPE
        "ck_tile::half_t"   # O_TYPE
        ${M} ${N}
        ${M_WAVE} ${N_WAVE}
        ${M_XDL} ${N_XDL} ${K_XDL}
        "FP8_${shape_tag}_fp16")
endfunction()
|
||||
|
||||
# FP8 -> FP8 benchmark: fp8_t inputs and output, float accumulator.
# Target: bench_lds_fp8_<tag>_fp8; config label: FP8_<tag>_fp8.
function(add_fp8_fp8_benchmark M N M_WAVE N_WAVE M_XDL N_XDL K_XDL)
    # "<tile>_<wave layout>" tag shared by the target name and the config label.
    set(shape_tag "${M_XDL}x${N_XDL}x${K_XDL}_${M_WAVE}x${N_WAVE}")
    add_cshuffle_lds_benchmark(
        "bench_lds_fp8_${shape_tag}_fp8"
        "ck_tile::fp8_t"    # A_TYPE
        "ck_tile::fp8_t"    # B_TYPE
        "float"             # ACC_TYPE
        "ck_tile::fp8_t"    # O_TYPE
        ${M} ${N}
        ${M_WAVE} ${N_WAVE}
        ${M_XDL} ${N_XDL} ${K_XDL}
        "FP8_${shape_tag}_fp8")
endfunction()
|
||||
|
||||
# FP32 benchmark: float inputs, accumulator and output.
# Target: bench_lds_fp32_<tag>; config label: FP32_<tag>.
function(add_fp32_benchmark M N M_WAVE N_WAVE M_XDL N_XDL K_XDL)
    # "<tile>_<wave layout>" tag shared by the target name and the config label.
    set(shape_tag "${M_XDL}x${N_XDL}x${K_XDL}_${M_WAVE}x${N_WAVE}")
    add_cshuffle_lds_benchmark(
        "bench_lds_fp32_${shape_tag}"
        "float"   # A_TYPE
        "float"   # B_TYPE
        "float"   # ACC_TYPE
        "float"   # O_TYPE
        ${M} ${N}
        ${M_WAVE} ${N_WAVE}
        ${M_XDL} ${N_XDL} ${K_XDL}
        "FP32_${shape_tag}")
endfunction()
|
||||
|
||||
# BF16 benchmark: bf16_t inputs and output, float accumulator.
# Target: bench_lds_bf16_<tag>; config label: BF16_<tag>.
function(add_bf16_benchmark M N M_WAVE N_WAVE M_XDL N_XDL K_XDL)
    # "<tile>_<wave layout>" tag shared by the target name and the config label.
    set(shape_tag "${M_XDL}x${N_XDL}x${K_XDL}_${M_WAVE}x${N_WAVE}")
    add_cshuffle_lds_benchmark(
        "bench_lds_bf16_${shape_tag}"
        "ck_tile::bf16_t"   # A_TYPE
        "ck_tile::bf16_t"   # B_TYPE
        "float"             # ACC_TYPE
        "ck_tile::bf16_t"   # O_TYPE
        ${M} ${N}
        ${M_WAVE} ${N_WAVE}
        ${M_XDL} ${N_XDL} ${K_XDL}
        "BF16_${shape_tag}")
endfunction()
|
||||
|
||||
# Helper to add benchmarks for all wave layouts of a given MFMA tile.
#
# For each wave layout (4x1, 2x2, 1x4) this calls
#   ${FUNC}(M N M_WAVE N_WAVE M_XDL N_XDL K_XDL)
# with the block tile sized so that every wave covers exactly one MFMA tile:
#   M = M_XDL * M_WAVE,  N = N_XDL * N_WAVE.
#
# Arguments:
#   FUNC              - name of a type-specific wrapper, e.g. add_fp16_benchmark
#   M_XDL N_XDL K_XDL - MFMA tile dimensions
#
# Example:
#   add_benchmarks_for_mfma(add_fp16_benchmark 32 32 8)
#
# This is a function() rather than a macro() so that the loop temporaries
# (M, N, M_WAVE, N_WAVE, WAVE_LAYOUT) stay local instead of leaking into the
# calling directory scope. Targets created by ${FUNC} are global either way.
function(add_benchmarks_for_mfma FUNC M_XDL N_XDL K_XDL)
    # Fail with a clear message instead of an obscure cmake_language() error
    # when the wrapper name is misspelled.
    if(NOT COMMAND "${FUNC}")
        message(FATAL_ERROR
            "add_benchmarks_for_mfma: '${FUNC}' is not a defined function or macro")
    endif()

    # Each quoted item is one "M_WAVE;N_WAVE" pair; quoting keeps the pair
    # together as a single loop value.
    foreach(WAVE_LAYOUT "4;1" "2;2" "1;4")
        list(GET WAVE_LAYOUT 0 M_WAVE)
        list(GET WAVE_LAYOUT 1 N_WAVE)
        math(EXPR M "${M_XDL} * ${M_WAVE}")
        math(EXPR N "${N_XDL} * ${N_WAVE}")
        cmake_language(CALL ${FUNC} ${M} ${N} ${M_WAVE} ${N_WAVE} ${M_XDL} ${N_XDL} ${K_XDL})
    endforeach()
endfunction()
|
||||
|
||||
# Each quoted item below is one "M_XDL;N_XDL;K_XDL" MFMA tile; expanding it
# unquoted passes the three dimensions as separate arguments.

#
# FP32 benchmarks
#
# MFMA tiles: 32x32x4, 32x32x8, 16x16x4, 16x16x8, 16x16x16
foreach(fp32_tile "32;32;4" "32;32;8" "16;16;4" "16;16;8" "16;16;16")
    add_benchmarks_for_mfma(add_fp32_benchmark ${fp32_tile})
endforeach()

#
# FP16 benchmarks
#
# MFMA tiles: 32x32x8, 32x32x16, 16x16x16, 4x64x16, 64x4x16
foreach(fp16_tile "32;32;8" "32;32;16" "16;16;16" "4;64;16" "64;4;16")
    add_benchmarks_for_mfma(add_fp16_benchmark ${fp16_tile})
endforeach()

#
# FP8 -> FP16 benchmarks, then FP8 -> FP8 benchmarks
#
# MFMA tiles (both output types): 32x32x16, 16x16x32
foreach(fp8_wrapper add_fp8_fp16_benchmark add_fp8_fp8_benchmark)
    foreach(fp8_tile "32;32;16" "16;16;32")
        add_benchmarks_for_mfma(${fp8_wrapper} ${fp8_tile})
    endforeach()
endforeach()
|
||||
#
# gfx950-only configurations
#
# Registered only when the SUPPORTED_GPU_TARGETS list mentions gfx950.
if(SUPPORTED_GPU_TARGETS MATCHES "gfx950")
    # FP16: 16x16x32
    add_benchmarks_for_mfma(add_fp16_benchmark 16 16 32)

    # BF16: 16x16x64 (gfx950-only, uses 16x16x32 base instruction)
    # Other BF16 tiles have same LDS behavior as FP16 since both are 2-byte types
    add_benchmarks_for_mfma(add_bf16_benchmark 16 16 64)

    # FP8 -> FP16 first, then FP8 -> FP8; both use the tiles
    # 32x32x32, 32x32x64, 16x16x64, 16x16x128.
    foreach(fp8_wrapper add_fp8_fp16_benchmark add_fp8_fp8_benchmark)
        foreach(fp8_tile "32;32;32" "32;32;64" "16;16;64" "16;16;128")
            add_benchmarks_for_mfma(${fp8_wrapper} ${fp8_tile})
        endforeach()
    endforeach()
endif()
|
||||
61
example/ck_tile/52_cshuffle_lds/README.md
Normal file
61
example/ck_tile/52_cshuffle_lds/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# CShuffleLds LDS Microbenchmarks
|
||||
|
||||
Microbenchmark suite for measuring LDS (Local Data Share) bandwidth and bank conflicts in the CShuffleEpilogue cross-lane shuffle patterns.
|
||||
|
||||
## What This Measures
|
||||
|
||||
The CShuffleEpilogue uses LDS to redistribute GEMM output tiles from MFMA register layout to thread-raked layout for efficient global memory writes. This benchmark isolates the LDS store/load operations to measure:
|
||||
|
||||
1. **Store bandwidth** - Writing accumulator tiles to LDS (MFMA → LDS)
|
||||
2. **Load bandwidth** - Reading shuffled tiles from LDS (LDS → thread-raked)
|
||||
3. **Bank conflicts** - LDS bank conflicts during store/load (via rocprofv3)
|
||||
|
||||
## Configurations
|
||||
|
||||
Benchmarks are generated for all combinations of:
|
||||
|
||||
- **FP32 MFMA tiles**: 32x32x4, 32x32x8, 16x16x4, 16x16x8, 16x16x16
|
||||
- **FP16 MFMA tiles**: 32x32x8, 32x32x16, 16x16x16, 4x64x16, 64x4x16
|
||||
- **FP8 MFMA tiles**: 32x32x16, 16x16x32 (output FP16 or FP8)
|
||||
- **Wave layouts**: 4x1, 2x2, 1x4 (block tile = MFMA tile × wave layout; e.g. a 32x32 MFMA tile with a 2x2 wave layout gives a 64x64 block tile)
|
||||
|
||||
**gfx950-only configurations:**
|
||||
- **FP16**: 16x16x32
|
||||
- **BF16**: 16x16x64 (uses gfx950-only 16x16x32 base instruction)
|
||||
- **FP8**: 32x32x32, 32x32x64, 16x16x64, 16x16x128 (output FP16 or FP8)
|
||||
|
||||
Each configuration produces two measurements: Store and Load.
|
||||
|
||||
## Building
|
||||
|
||||
```bash
|
||||
cmake -G Ninja -B build -S . \
|
||||
-DGPU_TARGETS=gfx950 \
|
||||
-DBUILD_CK_EXAMPLES=ON \
|
||||
-DBUILD_CK_TILE_CSHUFFLE_LDS_BENCHMARKS=ON
|
||||
|
||||
ninja -C build bench_lds_fp8_16x16x128_2x2_fp8 # Single benchmark
|
||||
```
|
||||
|
||||
## Running
|
||||
|
||||
```bash
|
||||
# Run a single benchmark
|
||||
./build/bin/bench_lds_fp8_16x16x128_2x2_fp8 --warmup 3 --iters 10
|
||||
|
||||
# Profile with rocprofv3 for bank conflicts
|
||||
cat > counters.txt <<EOF
|
||||
pmc: SQ_LDS_BANK_CONFLICT SQ_INSTS_LDS
|
||||
EOF
|
||||
|
||||
rocprofv3 -i counters.txt -d output/ -- \
|
||||
./build/bin/bench_lds_fp8_16x16x128_2x2_fp8
|
||||
```
|
||||
|
||||
## Implementation
|
||||
|
||||
- **Generic kernels**: `include/ck_tile/utility/tile_load_store_microkernels.hpp`
|
||||
- **Setup adapters**: `benchmark_cshuffle_lds.hpp`
|
||||
- **Template generation**: `benchmark_template.cpp.in`
|
||||
|
||||
The benchmark uses CK's `launch_kernel` infrastructure for timing and `make_kernel` for functor-based kernel dispatch.
|
||||
122
example/ck_tile/52_cshuffle_lds/benchmark_cshuffle_lds.hpp
Normal file
122
example/ck_tile/52_cshuffle_lds/benchmark_cshuffle_lds.hpp
Normal file
@@ -0,0 +1,122 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* @file benchmark_cshuffle_lds.hpp
|
||||
* @brief LDS benchmark setup for CShuffleEpilogue.
|
||||
*
|
||||
* Provides Setup adapters that extract LDS descriptor and distribution
|
||||
* from CShuffleEpilogue for use with generic tile benchmark kernels.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/utility/tile_load_store_microkernels.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Create CShuffleEpilogue type from benchmark parameters.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename ODataType,
|
||||
index_t kM,
|
||||
index_t kN,
|
||||
index_t MWave,
|
||||
index_t NWave,
|
||||
index_t MPerXdl,
|
||||
index_t NPerXdl,
|
||||
index_t KPerXdl>
|
||||
using BenchmarkEpilogue = CShuffleEpilogue<CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
tuple<>,
|
||||
AccDataType,
|
||||
ODataType,
|
||||
tuple<>,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
element_wise::PassThrough,
|
||||
kM,
|
||||
kN,
|
||||
MWave,
|
||||
NWave,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl,
|
||||
false>>;
|
||||
|
||||
/**
|
||||
* @brief Setup for LDS store benchmark - adapts CShuffleEpilogue for tile benchmark.
|
||||
*/
|
||||
template <typename Epilogue>
|
||||
struct LdsStoreSetup
|
||||
{
|
||||
using ODataType = typename Epilogue::ODataType;
|
||||
static constexpr index_t kBlockSize = Epilogue::kBlockSize;
|
||||
static constexpr index_t kBytes =
|
||||
Epilogue::MPerIterationShuffle * Epilogue::NPerIterationShuffle * sizeof(ODataType);
|
||||
static constexpr auto lds_desc =
|
||||
Epilogue::template MakeLdsBlockDescriptor<typename Epilogue::Problem>();
|
||||
static constexpr auto distr =
|
||||
make_static_tile_distribution(Epilogue::MakeLdsDistributionEncode());
|
||||
|
||||
CK_TILE_DEVICE static auto create()
|
||||
{
|
||||
alignas(16) __shared__ char smem[Epilogue::GetSmemSize()];
|
||||
|
||||
auto lds_view =
|
||||
make_tensor_view<address_space_enum::lds>(reinterpret_cast<ODataType*>(smem), lds_desc);
|
||||
|
||||
auto window = make_tile_window(lds_view,
|
||||
make_tuple(number<Epilogue::MPerIterationShuffle>{},
|
||||
number<Epilogue::NPerIterationShuffle>{}),
|
||||
{0, 0},
|
||||
distr);
|
||||
|
||||
auto tile = make_static_distributed_tensor<ODataType>(distr);
|
||||
|
||||
return make_tuple(window, tile);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Setup for LDS load benchmark - adapts CShuffleEpilogue for tile benchmark.
|
||||
*/
|
||||
template <typename Epilogue>
|
||||
struct LdsLoadSetup
|
||||
{
|
||||
using ODataType = typename Epilogue::ODataType;
|
||||
static constexpr index_t kBlockSize = Epilogue::kBlockSize;
|
||||
static constexpr index_t kBytes =
|
||||
Epilogue::MPerIterationShuffle * Epilogue::NPerIterationShuffle * sizeof(ODataType);
|
||||
static constexpr auto lds_desc =
|
||||
Epilogue::template MakeLdsBlockDescriptor<typename Epilogue::Problem>();
|
||||
|
||||
using ReadPattern =
|
||||
tile_distribution_encoding_pattern_2d<Epilogue::kBlockSize,
|
||||
Epilogue::MPerIterationShuffle,
|
||||
Epilogue::NPerIterationShuffle,
|
||||
Epilogue::GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
static constexpr auto read_distr = ReadPattern::make_2d_static_tile_distribution();
|
||||
|
||||
CK_TILE_DEVICE static auto create()
|
||||
{
|
||||
alignas(16) __shared__ char smem[Epilogue::GetSmemSize()];
|
||||
|
||||
auto lds_view =
|
||||
make_tensor_view<address_space_enum::lds>(reinterpret_cast<ODataType*>(smem), lds_desc);
|
||||
|
||||
return make_tile_window(lds_view,
|
||||
make_tuple(number<Epilogue::MPerIterationShuffle>{},
|
||||
number<Epilogue::NPerIterationShuffle>{}),
|
||||
{0, 0},
|
||||
read_distr);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
100
example/ck_tile/52_cshuffle_lds/benchmark_template.cpp.in
Normal file
100
example/ck_tile/52_cshuffle_lds/benchmark_template.cpp.in
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// clang-format off
|
||||
|
||||
#include "benchmark_cshuffle_lds.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
|
||||
|
||||
using Epilogue = ck_tile::BenchmarkEpilogue<
|
||||
@A_TYPE@, @B_TYPE@, @ACC_TYPE@, @O_TYPE@,
|
||||
@M@, @N@, @M_WAVE@, @N_WAVE@, @M_XDL@, @N_XDL@, @K_XDL@>;
|
||||
|
||||
using StoreSetup = ck_tile::LdsStoreSetup<Epilogue>;
|
||||
using LoadSetup = ck_tile::LdsLoadSetup<Epilogue>;
|
||||
|
||||
// Prints the usage text for this benchmark to stdout.
// `prog` is shown as the program name on the "Usage:" line; the configuration
// section is filled in by configure_file when the template is instantiated.
void print_help(const char* prog)
{
    // Fixed text that follows the program name on the usage line.
    static const char* const kHelpBody =
        " [options]\n"
        "\n"
        "LDS microbenchmark for CShuffleEpilogue (@CONFIG_NAME@)\n"
        "\n"
        "Options:\n"
        " -w, --warmup <N> Warmup iterations (default: 3)\n"
        " -i, --iters <N> Benchmark iterations (default: 10)\n"
        " -h, --help Show this help message\n"
        "\n"
        "Configuration:\n"
        " MFMA tile: @M_XDL@x@N_XDL@x@K_XDL@\n"
        " Wave layout: @M_WAVE@x@N_WAVE@\n"
        " Block tile: @M@x@N@\n";

    std::cout << "Usage: " << prog << kHelpBody << std::endl;
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
int warmup = 3;
|
||||
int iters = 10;
|
||||
|
||||
for (int i = 1; i < argc; ++i)
|
||||
{
|
||||
if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0)
|
||||
{
|
||||
print_help(argv[0]);
|
||||
return 0;
|
||||
}
|
||||
else if ((std::strcmp(argv[i], "-w") == 0 || std::strcmp(argv[i], "--warmup") == 0) && i + 1 < argc)
|
||||
{
|
||||
int val = std::atoi(argv[++i]);
|
||||
if (val <= 0)
|
||||
{
|
||||
std::cerr << "Error: --warmup requires a positive integer\n";
|
||||
return 1;
|
||||
}
|
||||
warmup = val;
|
||||
}
|
||||
else if ((std::strcmp(argv[i], "-i") == 0 || std::strcmp(argv[i], "--iters") == 0) && i + 1 < argc)
|
||||
{
|
||||
int val = std::atoi(argv[++i]);
|
||||
if (val <= 0)
|
||||
{
|
||||
std::cerr << "Error: --iters requires a positive integer\n";
|
||||
return 1;
|
||||
}
|
||||
iters = val;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unknown option: " << argv[i] << "\n";
|
||||
print_help(argv[0]);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "=== @CONFIG_NAME@ ===" << std::endl;
|
||||
|
||||
ck_tile::stream_config stream{nullptr, true, 0, warmup, iters, true};
|
||||
|
||||
// Store benchmark
|
||||
{
|
||||
float ms = ck_tile::launch_kernel(stream,
|
||||
ck_tile::make_kernel(ck_tile::StoreTile<StoreSetup>{},
|
||||
dim3(1), dim3(StoreSetup::kBlockSize), 0));
|
||||
double gb_s = (double(StoreSetup::kBytes) / 1e9) / (ms / 1e3);
|
||||
std::cout << "Store: " << ms << " ms, " << gb_s << " GB/s" << std::endl;
|
||||
}
|
||||
|
||||
// Load benchmark
|
||||
{
|
||||
float ms = ck_tile::launch_kernel(stream,
|
||||
ck_tile::make_kernel(ck_tile::LoadTile<LoadSetup>{},
|
||||
dim3(1), dim3(LoadSetup::kBlockSize), 0));
|
||||
double gb_s = (double(LoadSetup::kBytes) / 1e9) / (ms / 1e3);
|
||||
std::cout << "Load: " << ms << " ms, " << gb_s << " GB/s" << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -33,4 +33,7 @@ add_subdirectory(41_batched_contraction)
|
||||
add_subdirectory(42_mx_gemm)
|
||||
add_subdirectory(50_sparse_attn)
|
||||
add_subdirectory(51_tile_distr_enc_reg_map)
|
||||
if(BUILD_CK_TILE_CSHUFFLE_LDS_BENCHMARKS)
|
||||
add_subdirectory(52_cshuffle_lds)
|
||||
endif()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user