Merge branch 'develop' into users/yiding12/fmha-bwd-workspace
@@ -10,27 +10,31 @@ RUN pip install pandas zmq einops ninja tabulate vcs_versioning && \
|
||||
sudo mkdir /home/jenkins/workspace && \
|
||||
cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \
|
||||
if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \
|
||||
git clone --depth 1 -b "$CK_AITER_BRANCH" --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \
|
||||
cd rocm-libraries && \
|
||||
mkdir rocm-libraries && cd rocm-libraries && \
|
||||
git init -q && \
|
||||
git remote add origin https://github.com/ROCm/rocm-libraries.git && \
|
||||
git fetch --depth 1 --filter=blob:none origin "$CK_AITER_BRANCH" && \
|
||||
git sparse-checkout init --cone && \
|
||||
git sparse-checkout set projects/composablekernel && \
|
||||
git checkout "$CK_AITER_BRANCH" && \
|
||||
git checkout FETCH_HEAD && \
|
||||
ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \
|
||||
LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}" && \
|
||||
mv projects/composablekernel ../ck && \
|
||||
cd ../ck && rm -rf ../rocm-libraries && \
|
||||
git init && \
|
||||
git init -b "$LOCAL_BRANCH" && \
|
||||
git config user.name "assistant-librarian[bot]" && \
|
||||
git config user.email "assistant-librarian[bot]@users.noreply.github.com" && \
|
||||
git branch -m "$CK_AITER_BRANCH" && git add -A && \
|
||||
git add -A && \
|
||||
git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" ; \
|
||||
else \
|
||||
git clone --depth 1 -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git ck ; \
|
||||
git clone --depth 1 -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git ck && \
|
||||
LOCAL_BRANCH="$CK_AITER_BRANCH" ; \
|
||||
fi && \
|
||||
cd /home/jenkins/workspace && rm -rf aiter && \
|
||||
git clone --depth 1 -b "$AITER_BRANCH" --recursive https://github.com/ROCm/aiter.git && \
|
||||
cd aiter && \
|
||||
rm -rf 3rdparty/composable_kernel/ && \
|
||||
git clone -b "$CK_AITER_BRANCH" ../ck 3rdparty/composable_kernel/ && \
|
||||
git clone -b "$LOCAL_BRANCH" ../ck 3rdparty/composable_kernel/ && \
|
||||
python3 setup.py develop && \
|
||||
groupadd -g 1001 jenkins && \
|
||||
useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \
|
||||
|
||||
@@ -12,27 +12,31 @@ RUN set -x ; \
|
||||
sudo mkdir /home/jenkins/workspace && \
|
||||
cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \
|
||||
if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \
|
||||
git clone --depth 1 -b "$CK_FA_BRANCH" --no-checkout --filter=blob:none https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \
|
||||
cd rocm-libraries && \
|
||||
mkdir rocm-libraries && cd rocm-libraries && \
|
||||
git init -q && \
|
||||
git remote add origin https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \
|
||||
git fetch --depth 1 --filter=blob:none origin "$CK_FA_BRANCH" && \
|
||||
git sparse-checkout init --cone && \
|
||||
git sparse-checkout set projects/composablekernel && \
|
||||
git checkout "$CK_FA_BRANCH" && \
|
||||
git checkout FETCH_HEAD && \
|
||||
ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \
|
||||
LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}" && \
|
||||
mv projects/composablekernel ../ck && \
|
||||
cd ../ck && rm -rf ../rocm-libraries && \
|
||||
git init && \
|
||||
git init -b "$LOCAL_BRANCH" && \
|
||||
git config user.name "assistant-librarian[bot]" && \
|
||||
git config user.email "assistant-librarian[bot]@users.noreply.github.com" && \
|
||||
git branch -m "$CK_FA_BRANCH" && git add -A && \
|
||||
git add -A && \
|
||||
git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" > /dev/null ; \
|
||||
else \
|
||||
git clone --depth 1 -b "$CK_FA_BRANCH" https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck ; \
|
||||
git clone --depth 1 -b "$CK_FA_BRANCH" https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck && \
|
||||
LOCAL_BRANCH="$CK_FA_BRANCH" ; \
|
||||
fi && \
|
||||
cd /home/jenkins/workspace && rm -rf flash-attention && \
|
||||
git clone --depth 1 -b "$FA_BRANCH" --recursive "https://github.com/$FA_ORIGIN/flash-attention.git" && \
|
||||
cd flash-attention && \
|
||||
rm -rf csrc/composable_kernel/ && \
|
||||
git clone -b "$CK_FA_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \
|
||||
git clone -b "$LOCAL_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \
|
||||
MAX_JOBS=$(nproc) GPU_ARCHS="$GPU_ARCHS" /opt/venv/bin/python3 -u -m pip install --no-build-isolation -v . && \
|
||||
groupadd -g 1001 jenkins && \
|
||||
useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \
|
||||
|
||||
17
Jenkinsfile
vendored
@@ -840,8 +840,10 @@ def cmake_build(Map conf=[:]){
|
||||
|
||||
if (params.RUN_CK_TILE_FMHA_TESTS){
|
||||
try{
|
||||
archiveArtifacts "perf_fmha_*.log"
|
||||
stash includes: "perf_fmha_**.log", name: "perf_fmha_log_${arch_name}"
|
||||
dir("projects/composablekernel"){
|
||||
archiveArtifacts "perf_fmha_*.log"
|
||||
stash includes: "perf_fmha_**.log", name: "perf_fmha_log_${arch_name}"
|
||||
}
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
|
||||
@@ -918,7 +920,7 @@ def Build_CK(Map conf=[:]){
|
||||
sh "projects/composablekernel/script/run_inductor_tests.sh"
|
||||
}
|
||||
// run performance tests, stash the logs, results will be processed on the master node
|
||||
dir("projects/composablekernel/script"){
|
||||
dir("projects/composablekernel/script"){
|
||||
if (params.RUN_PERFORMANCE_TESTS){
|
||||
if (params.RUN_FULL_QA && (arch == "gfx90a" || arch == "gfx942")){
|
||||
// run full tests on gfx90a or gfx942
|
||||
@@ -1017,6 +1019,13 @@ def process_results(Map conf=[:]){
|
||||
catch(Exception err){
|
||||
echo "could not locate the FMHA performance logs for gfx90a: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_fmha_log_gfx950"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the FMHA performance logs for gfx950: ${err.getMessage()}."
|
||||
}
|
||||
|
||||
}
|
||||
if (params.BUILD_INSTANCES_ONLY){
|
||||
// unstash deb packages
|
||||
@@ -1191,7 +1200,7 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_
|
||||
0 13 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true
|
||||
0 11 * * * % RUN_FULL_CONV_TILE_TESTS=true;RUN_AITER_TESTS=true;RUN_FA_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true
|
||||
0 9 * * * % RUN_PYTORCH_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX101=false;BUILD_GFX103=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : ""
|
||||
CURRENT_BRANCH_NAME = env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME
|
||||
CURRENT_BRANCH_NAME = env.CHANGE_ID ? "refs/pull/${env.CHANGE_ID}/head" : (env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME)
|
||||
|
||||
POLL_SPEC = BRANCH_NAME == "develop" ? 'H H/6 * * *' : ''
|
||||
|
||||
|
||||
@@ -108,28 +108,35 @@ bool run_grouped_conv_fwd(bool do_verification,
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<AccDataType> c_host(out_g_n_k_wos_desc);
|
||||
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>();
|
||||
PassThrough>();
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(in,
|
||||
wei,
|
||||
out_host,
|
||||
c_host,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
out_host.ForEach([&](auto&, auto idx)
|
||||
{
|
||||
out_element_op(out_host(idx), c_host(idx));
|
||||
});
|
||||
|
||||
out_device_buf.FromDevice(out_device.mData.data());
|
||||
|
||||
pass &=
|
||||
|
||||
@@ -22,8 +22,16 @@ from codegen.cpp_symbol_map import (
    QSCALE_CHECK_MAP,
    QSCALE_MAP,
)
from codegen.arch import ArchTrait
from codegen.utils import update_file

# Architecture trait for kernels requiring global_load_lds (CDNA3+).
# Only used for GLOBAL_LOAD_LDS variants; all other kernels are arch-agnostic.
CDNA3_PLUS_ARCH = ArchTrait(
    "cdna3_plus",
    preprocessor_check="defined(__gfx94__) || defined(__gfx950__)",
)
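# Illustration (an editor's note, not emitted by this assignment itself): when a
# generated kernel is a GLOBAL_LOAD_LDS variant, {F_arch_check} in
# FMHA_FWD_KERNEL_BODY is substituted with this preprocessor_check, so the
# emitted file is guarded as
#   #if !defined(__HIP_DEVICE_COMPILE__) || (defined(__gfx94__) || defined(__gfx950__))
# while arch-agnostic kernels receive the constant "true" and always compile.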

DTYPE_BITS = {
    "fp32": 32,
    "fp16": 16,
@@ -34,6 +42,10 @@ DTYPE_BITS = {
    "bf8": 8,
}

# Element size in bytes per dtype, used by the auto-generated dispatcher to
# decide kv_load_mode per-arm (total KV cache bytes vs INT32_MAX).
DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()}

K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}

SUPPORTED_PAGE_SIZE = [1, 16, 1024]

@@ -47,6 +59,10 @@ KV_LOOKUP_TABLE_ENUM_MAP = {
    "vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
    "sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
}
KV_LOAD_MODE_ENUM_MAP = {
    False: "ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD",
    True: "ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS",
}
|
||||
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
@@ -61,6 +77,8 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
|
||||
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
@@ -87,7 +105,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
|
||||
{F_sink},
|
||||
{F_page_size},
|
||||
{F_kv_memory_layout},
|
||||
{F_kv_lookup_table}>;
|
||||
{F_kv_lookup_table},
|
||||
{F_kv_load_mode}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
@@ -125,7 +144,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -140,10 +159,13 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
namespace {{
|
||||
@@ -194,6 +216,7 @@ float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a,
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
constexpr int kElementBytes = {F_element_bytes};
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
@@ -203,8 +226,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && (fmha_batch_prefill_select_kv_load_mode(a.page_block_size, {F_bn0}, a.num_total_pages, a.batch_stride_k, kElementBytes) == {F_kv_load_mode})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -253,12 +276,14 @@ class FmhaFwdApiTrait:
|
||||
kv_memory_layout: str
|
||||
kv_lookup_table: str
|
||||
page_size: int = 1 # page block size
|
||||
use_global_load: bool = False # use global_load_lds_* for >2GB KV cache
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
|
||||
+ ("-gload" if self.use_global_load else "-bload")
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -481,6 +506,7 @@ class FmhaFwdApiPool:
|
||||
],
|
||||
F_page_size=trait.page_size,
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[trait.use_global_load],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
@@ -488,7 +514,10 @@ class FmhaFwdApiPool:
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
F_if=if_i,
|
||||
F_dtype=dtype,
|
||||
F_element_bytes=DTYPE_BYTES[dtype],
|
||||
F_hdim_case=per_hdim_case,
|
||||
)
|
||||
if not per_dtypes:
|
||||
# per_dtypes is empty: add an ignore placeholder to suppress the warning in the api
|
||||
@@ -539,6 +568,7 @@ class FmhaFwdKernel:
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
F_page_size: int = 1 # page block size
|
||||
F_use_global_load: bool = False # use global_load_lds_* for >2GB KV cache
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
@@ -588,6 +618,10 @@ class FmhaFwdKernel:
|
||||
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_page_size=self.F_page_size,
|
||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
||||
F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[self.F_use_global_load],
|
||||
F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check
|
||||
if self.F_use_global_load
|
||||
else "true",
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -595,6 +629,7 @@ class FmhaFwdKernel:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
|
||||
+ ("gload_" if self.F_use_global_load else "bload_")
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
@@ -632,6 +667,7 @@ class FmhaFwdKernel:
|
||||
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
|
||||
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
|
||||
page_size=self.F_page_size,
|
||||
use_global_load=self.F_use_global_load,
|
||||
)
|
||||
|
||||
|
||||
@@ -714,8 +750,11 @@ class CustomFactory(KernelComponentFactory):
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl,
|
||||
targets: Optional[List[str]] = None
|
||||
kernel_filter: Optional[str],
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
targets: Optional[List[str]] = None,
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing
|
||||
# (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with
|
||||
@@ -837,6 +876,25 @@ def get_fwd_blobs(
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
# For page_size < kN0 (tile.F_bn0), also generate a GLOBAL_LOAD_LDS
|
||||
# variant for >2GB KV cache support. The default (BUFFER_LOAD) uses SRD
|
||||
# buffer_load (fast, <2GB). GLOBAL_LOAD_LDS uses global_load_lds_*
|
||||
# (slower, handles >2GB).
|
||||
if page_size < tile.F_bn0:
|
||||
k_global_load = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
F_page_size=page_size,
|
||||
F_use_global_load=True,
|
||||
)
|
||||
api_pool.register_traits(k_global_load.api_trait())
|
||||
gen.append(k_global_load)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
@@ -856,7 +914,9 @@ def write_blobs(
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
|
||||
api_pool, kernels = get_fwd_blobs(
|
||||
kernel_filter, receipt, optdim_list, mask_impl, targets
|
||||
)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
@@ -871,7 +931,9 @@ def list_blobs(
|
||||
mask_impl,
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
|
||||
_, kernels = get_fwd_blobs(
|
||||
kernel_filter, receipt, optdim_list, mask_impl, targets
|
||||
)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
|
||||
@@ -673,6 +673,33 @@ struct fmha_batch_prefill_args
    ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
};

// Selects the KV-cache load mode for a batch-prefill dispatch arm.
// GLOBAL_LOAD_LDS: required when (a) the page is smaller than one K/V tile
// so per-page SRD is impossible, AND (b) the total KV-pool byte size
// exceeds INT32_MAX so SRD's 32-bit byte offset cannot address it.
// BUFFER_LOAD: every other case — the SGPR-resident SRD path is fastest.
// Inputs are taken as plain integers so the helper has no template parameter
// and can be called from each codegen-emitted dispatcher arm with the arm's
// compile-time kN0 / element_bytes substituted as constants.
inline ck_tile::BlockAttentionKVCacheLoadModeEnum
fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size,
                                       ck_tile::index_t kN0,
                                       ck_tile::index_t num_total_pages,
                                       ck_tile::index_t batch_stride_k,
                                       ck_tile::index_t element_bytes)
{
    // Promote every operand to long_index_t so overflow is impossible regardless
    // of multiplication order. A bare `static_cast<long_index_t>(num_total_pages)
    // * batch_stride_k * element_bytes` only works because of left-to-right
    // associativity — a future reorder of the operands would silently truncate.
    const auto kv_pool_bytes = static_cast<ck_tile::long_index_t>(num_total_pages) *
                               static_cast<ck_tile::long_index_t>(batch_stride_k) *
                               static_cast<ck_tile::long_index_t>(element_bytes);
    return (page_block_size < kN0 && kv_pool_bytes > INT32_MAX)
               ? ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS
               : ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD;
}

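// Illustration of how the generated dispatcher consumes this helper (an
// editor's sketch; the concrete trait values shown are hypothetical, but the
// call shape matches the arms emitted by the batch-prefill codegen, which
// substitutes each arm's kN0 and a per-dtype kElementBytes constant):
//
//   else if(/* ...trait and padding checks for this arm... */
//           (t.page_size == 16) &&
//           (fmha_batch_prefill_select_kv_load_mode(a.page_block_size,
//                                                   /*kN0=*/128,
//                                                   a.num_total_pages,
//                                                   a.batch_stride_k,
//                                                   kElementBytes) ==
//            ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS))
//   {
//       using trait_ = fmha_fwd_batch_prefill_traits_</* ...this arm's params... */>;
//       return fmha_batch_prefill_<trait_>(s, a);
//   }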
template <typename FmhaKernel>
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
{
|
||||
@@ -1457,7 +1484,9 @@ template <ck_tile::index_t HDim_,
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
|
||||
ck_tile::BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
|
||||
ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
|
||||
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
|
||||
DataType_,
|
||||
kIsGroupMode_,
|
||||
@@ -1486,6 +1515,7 @@ struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
|
||||
static constexpr auto kKVLoadMode = kKVLoadMode_;
|
||||
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
|
||||
};
|
||||
|
||||
|
||||
@@ -33,8 +33,9 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool TransposeC = false,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
bool TransposeC = false,
|
||||
bool ALdsScalarLoadToVgpr = false,
|
||||
bool BLdsScalarLoadToVgpr = false>
|
||||
struct BlockwiseGemmXdlops_pipeline_base
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -389,7 +390,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
LdsScalarLoadToVgpr ? 1 : A_K1,
|
||||
ALdsScalarLoadToVgpr ? 1 : A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
@@ -399,7 +400,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
LdsScalarLoadToVgpr ? 1 : B_K1,
|
||||
BLdsScalarLoadToVgpr ? 1 : B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
|
||||
@@ -32,12 +32,13 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool DirectLoad = false,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
bool DirectLoad = false,
|
||||
bool ALdsScalarLoadToVgpr = false,
|
||||
bool BLdsScalarLoadToVgpr = false>
|
||||
constexpr auto BlockGemmPipeline_Selector()
|
||||
{
|
||||
// Supported for Direct Load and V1
|
||||
if constexpr(LdsScalarLoadToVgpr)
|
||||
if constexpr(ALdsScalarLoadToVgpr || BLdsScalarLoadToVgpr)
|
||||
{
|
||||
static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1);
|
||||
}
|
||||
@@ -65,7 +66,8 @@ constexpr auto BlockGemmPipeline_Selector()
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
LdsScalarLoadToVgpr>{};
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>{};
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
|
||||
@@ -747,7 +747,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
bool ALdsScalarLoadToVgpr = false,
|
||||
bool BLdsScalarLoadToVgpr = false>
|
||||
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1
|
||||
{
|
||||
};
|
||||
@@ -772,7 +773,8 @@ template <index_t BlockSize,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
bool LdsScalarLoadToVgpr>
|
||||
bool ALdsScalarLoadToVgpr,
|
||||
bool BLdsScalarLoadToVgpr>
|
||||
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -793,7 +795,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
LdsScalarLoadToVgpr>
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -814,7 +817,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
NRepeat,
|
||||
KPack,
|
||||
false /*TransposeC*/,
|
||||
LdsScalarLoadToVgpr>
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
@@ -837,7 +841,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
NRepeat,
|
||||
KPack,
|
||||
false /*TransposeC*/,
|
||||
LdsScalarLoadToVgpr>;
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>;
|
||||
using Base::I0;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
File diff suppressed because it is too large
@@ -19,6 +19,7 @@
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
@@ -856,6 +857,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
|
||||
@@ -179,6 +179,7 @@ struct DeviceGroupedConvBwdWeight_Explicit
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
{
|
||||
|
||||
@@ -670,6 +670,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
|
||||
@@ -695,6 +695,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
|
||||
@@ -611,6 +611,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer_v2
|
||||
|
||||
@@ -717,6 +717,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
// Create initial descriptors with hack=false to check compactness
|
||||
const auto descs_initial =
|
||||
|
||||
@@ -555,6 +555,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
|
||||
@@ -669,6 +669,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
// Create descriptors first (with hack flags temporarily set to false)
|
||||
// so we can check if element space sizes are divisible by k_batch
|
||||
|
||||
@@ -408,10 +408,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
? 4 / sizeof(BDataType)
|
||||
: BBlockTransferSrcScalarPerVector;
|
||||
|
||||
static constexpr bool ALdsScalarLoadToVgpr =
|
||||
(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
|
||||
static constexpr bool BLdsScalarLoadToVgpr =
|
||||
(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
|
||||
|
||||
// Note: Direct load uses the layout to create the proper block and mma tile descriptors
|
||||
// TODO: Fix and verify RC layout for not direct load (currently it returns wrong results)
|
||||
template <index_t NXdlPerWave_>
|
||||
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3<
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
std::conditional_t<DirectLoad,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
tensor_layout::gemm::RowMajor>,
|
||||
std::conditional_t<DirectLoad,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor>,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -456,7 +467,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
DirectLoad>;
|
||||
DirectLoad,
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
@@ -625,6 +638,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
// Create descriptors first (with hack flags temporarily set to false)
|
||||
// so we can check if element space sizes match product of dimensions
|
||||
|
||||
@@ -162,6 +162,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
|
||||
id_off += grid_size_grp;
|
||||
id_local += grid_size_grp;
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
||||
@@ -136,6 +136,7 @@ __launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU)
|
||||
|
||||
id_off += grid_size_grp;
|
||||
id_local += grid_size_grp;
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
||||
@@ -13,6 +13,13 @@ namespace ck {
namespace tensor_operation {
namespace device {

/// Ensures GemmKBatch in conv to GEMM transforms is never 0 (would zero the divisor in
/// integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch)).
inline constexpr index_t clamp_gemm_k_batch(index_t k_batch) noexcept
{
    return k_batch < 1 ? index_t{1} : k_batch;
}

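// Usage note (editor's sketch): the device ops in this commit call the clamp
// right after assigning the user-provided split_k, e.g.
//
//   k_batch_ = split_k;
//   k_batch_ = clamp_gemm_k_batch(k_batch_); // split_k <= 0 becomes 1
//
// so a downstream tiling such as
//   integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * k_batch_)
// can never divide by zero.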
struct DeviceProperties
|
||||
{
|
||||
DeviceProperties()
|
||||
@@ -33,6 +40,10 @@ inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index
|
||||
const int max_capacity = max_occupancy * device_properties.num_cu_;
|
||||
|
||||
ck::index_t k_batch = 1;
|
||||
if(grid_size <= 0)
|
||||
{
|
||||
return k_batch;
|
||||
}
|
||||
const auto optimal_split =
|
||||
static_cast<ck::index_t>(std::floor((1.0 * max_capacity) / grid_size));
|
||||
if(optimal_split > 1)
|
||||
|
||||
@@ -66,7 +66,9 @@ template <typename ALayout,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool DirectLoad = false>
|
||||
bool DirectLoad = false,
|
||||
bool ALdsScalarLoadToVgpr = false,
|
||||
bool BLdsScalarLoadToVgpr = false>
|
||||
struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
: public GridwiseGemm_xdl_cshuffle_base<
|
||||
ALayout,
|
||||
@@ -249,19 +251,90 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
return math::integer_divide_ceil(N, NPerBlock);
|
||||
}
|
||||
|
||||
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
|
||||
template <typename GridDesc_K0_MN_K1_T, index_t K0Number, index_t K1Value>
|
||||
__host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc)
|
||||
{
|
||||
|
||||
if constexpr(!DirectLoad)
|
||||
{
|
||||
return desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t K = desc.GetLength(I0) * desc.GetLength(I2);
|
||||
const index_t MN = desc.GetLength(I1);
|
||||
|
||||
const auto desc_unmerged = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, K0Number)),
|
||||
make_pass_through_transform(MN),
|
||||
make_pass_through_transform(K1Value)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto desc_permuted = transform_tensor_descriptor(
|
||||
desc_unmerged,
|
||||
make_tuple(make_pass_through_transform(K / KPerBlock),
|
||||
make_xor_with_modulo_transform(make_tuple(MN, K0Number)),
|
||||
make_pass_through_transform(K1Value)),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc_permuted,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, K0Number)),
|
||||
make_pass_through_transform(MN),
|
||||
make_pass_through_transform(K1Value)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t MNXdlPerWave,
|
||||
index_t MNWaves,
|
||||
index_t MNPerXdl,
|
||||
bool IsKContinous,
|
||||
typename TileDesc_K0_MN_K1>
|
||||
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
|
||||
{
|
||||
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
|
||||
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
|
||||
if constexpr(DirectLoad && IsKContinous)
|
||||
{
|
||||
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
|
||||
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_K0_MN_K1{},
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
|
||||
|
||||
constexpr auto desc = transform_tensor_descriptor(
|
||||
TileDesc_K0_MN_K1{},
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(Number<MN>{}, Number<K0>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
|
||||
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_K0_MN_K1{},
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ABlockDesc_AK0_M_AK1>
|
||||
@@ -270,7 +343,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
|
||||
return MakeGemmMmaTileDescriptor<MXdlPerWave,
|
||||
MWaves,
|
||||
MPerXdl,
|
||||
is_same<tensor_layout::gemm::RowMajor, ALayout>::value>(
|
||||
ABlockDesc_AK0_M_AK1{});
|
||||
}
|
||||
|
||||
template <typename BBlockDesc_BK0_N_BK1>
|
||||
@@ -279,7 +356,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
{
|
||||
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
|
||||
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
|
||||
return MakeGemmMmaTileDescriptor<NXdlPerWave,
|
||||
NWaves,
|
||||
NPerXdl,
|
||||
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value>(
|
||||
BBlockDesc_BK0_N_BK1{});
|
||||
}
|
||||
|
||||
struct Problem
|
||||
@@ -366,9 +447,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
{
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(Number<MPerBlock * AK1Number>{}, I1, Number<MPerBlock>{}));
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(Number<MPerBlock * AK1Number>{}, I1, Number<MPerBlock>{}));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<DeviceArch, gfx950_t>)
|
||||
{
|
||||
@@ -389,9 +479,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
{
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
|
||||
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<DeviceArch, gfx950_t>)
|
||||
{
|
||||
@@ -410,34 +509,35 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
|
||||
// Disable vector load from lds to vgpr for direct load (backward weight store with continuous M
|
||||
// or N dimension)
|
||||
static constexpr bool LdsScalarLoadToVgpr = DirectLoad;
|
||||
using BlockwiseGemmPipe = remove_cvref_t<
|
||||
decltype(BlockGemmPipeline_Selector<
|
||||
BlkGemmPipelineVer,
|
||||
BlkGemmPipeSched,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeTypeA,
|
||||
AccDataType,
|
||||
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
|
||||
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
|
||||
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
|
||||
// static constexpr bool LdsScalarLoadToVgpr = DirectLoad;
|
||||
using BlockwiseGemmPipe = remove_cvref_t<
|
||||
decltype(BlockGemmPipeline_Selector<
|
||||
BlkGemmPipelineVer,
|
||||
BlkGemmPipeSched,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeTypeA,
|
||||
AccDataType,
|
||||
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
|
||||
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
|
||||
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
|
||||
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))),
|
||||
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
|
||||
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
|
||||
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))),
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
DirectLoad,
|
||||
LdsScalarLoadToVgpr>())>;
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
DirectLoad,
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>())>;
|
||||
|
||||
template <typename DeviceArch>
|
||||
__device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch)
|
||||
@@ -517,8 +617,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1)
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1,
|
||||
const index_t block_idx_x = static_cast<index_t>(blockIdx.x))
|
||||
{
|
||||
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
|
||||
@@ -535,8 +636,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
|
||||
make_multi_index(static_cast<index_t>(blockIdx.x)));
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx_x));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
@@ -570,23 +671,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
auto get_a_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
|
||||
is_same<tensor_layout::gemm::RowMajor, ALayout>::value ? 2 : 1,
|
||||
ABlockTransferSrcScalarPerVector >
|
||||
(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(
|
||||
SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -626,23 +723,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
auto get_b_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
|
||||
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value ? 2 : 1,
|
||||
BBlockTransferSrcScalarPerVector >
|
||||
(b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(
|
||||
SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -750,8 +843,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1)
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1,
|
||||
const index_t block_idx_x = static_cast<index_t>(blockIdx.x))
|
||||
{
|
||||
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
|
||||
@@ -771,7 +865,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
|
||||
make_multi_index(static_cast<index_t>(blockIdx.x)));
|
||||
make_multi_index(static_cast<index_t>(block_idx_x)));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
@@ -805,23 +899,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
auto get_a_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
|
||||
is_same<tensor_layout::gemm::RowMajor, ALayout>::value ? 2 : 1,
|
||||
ABlockTransferSrcScalarPerVector >
|
||||
(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(
|
||||
SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -861,23 +951,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
auto get_b_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
|
||||
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value ? 2 : 1,
|
||||
BBlockTransferSrcScalarPerVector >
|
||||
(b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(
|
||||
SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -21,6 +21,10 @@ template <index_t NDimSpatial,
|
||||
device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
|
||||
struct TransformConvBwdWeightToGemm
|
||||
{
|
||||
// Same contract as TransformConvBwdWeightToGemmV2 (non-zero K tile factors).
|
||||
static_assert(GemmK1Number > 0, "GemmK1Number must be positive");
|
||||
static_assert(K0PerBlock > 0, "K0PerBlock must be positive");
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
|
||||
@@ -31,6 +31,11 @@ template <index_t NDimSpatial,
|
||||
device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
|
||||
struct TransformConvBwdWeightToGemmV2
|
||||
{
|
||||
// Compile-time contract: divisor GemmK1Number * K0PerBlock * GemmKBatch in
|
||||
// integer_divide_ceil(GemmKTotal, ...) must stay non-zero (GemmKBatch clamped at runtime).
|
||||
static_assert(GemmK1Number > 0, "GemmK1Number must be positive");
|
||||
static_assert(K0PerBlock > 0, "K0PerBlock must be positive");
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
|
||||
@@ -1319,6 +1319,87 @@ CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}

// Flat async load from global memory to LDS using 64-bit global addressing.
// Bypasses the SRD's 32-bit offset limit; required when the KV cache exceeds
// INT32_MAX (2GB) byte offset on the SRD voffset path.
//
// !!! M0 PRECONDITION — IMPLICIT INPUT NOT VISIBLE IN OPERAND LIST !!!
//
// The LDS destination address is taken from M0 (per AMD CDNA3 ISA §10.3:
// `LDS_ADDR = LDSbase + LDSoffset(M0[17:2] * 4) + INST.OFFSET + ThreadID*4`).
// M0 does NOT appear as an operand of these instructions or of the inline
// asm below — the compiler cannot see the dependency. Caller must:
//
// 1. Initialize M0 once before the load loop:
//    `m0_set_with_memory(amd_wave_read_first_lane(lds_byte_offset));`
//    M0 is SALU-only — `m0_set_with_memory` uses an "s" constraint to
//    enforce this. Direct VALU writes to M0 are illegal.
//
// 2. Advance M0 between successive issues:
//    `m0_inc_with_memory(size_per_issue);`
//    `size_per_issue` MUST be a multiple of 4 — GLOBAL/FLAT LDS path
//    only honors M0[17:2]*4 (dword-aligned), so low 2 bits are silently
//    dropped (NOTE: this differs from MUBUF buffer_load_lds which uses
//    M0[15:0] as a raw byte offset).
//
// 3. Never bundle `m0_inc_with_memory` and the next call to this
//    function into a single inline asm. The compiler auto-inserts a
//    hazard NOP between an SALU write to M0 and the consuming
//    `global_load_lds_*`; bundling bypasses that and may read stale M0.
//
// The "memory" clobber on this asm is load-bearing: it prevents the
// compiler from reordering this load across other M0-touching helpers
// (`m0_set_with_memory` / `m0_inc_with_memory`, also "memory"-clobbered).
//
// Verified instruction emission (HIP 6.4 / clang 19, gfx942 + gfx950):
// `global_load_lds_dwordx4` is a single instruction (encoding 0xDDF48000
// 0x007F0000), NOT software-expanded into 4× dword. Same encoding on both
// arches. The opcode is undocumented in CDNA3 ISA spec §13.6.2 but
// supported by the LLVM AMDGPU backend.
//
// Available on gfx940+ (CDNA3+: MI300, MI355, MI350 series).
template <unsigned num_dwords, bool pre_nop = false>
CK_TILE_DEVICE void
async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant<pre_nop> = {})
{
#if !defined(__gfx94__) && !defined(__gfx950__)
    static_assert(always_false_v<integral_constant<unsigned, num_dwords>>,
                  "global_load_lds requires CDNA3+ (gfx940/gfx950). "
                  "Ensure kKVLoadMode is BUFFER_LOAD on this architecture.");
#endif

    static_assert(num_dwords == 1 || num_dwords == 4,
                  "global_load_lds supports num_dwords == 1 or 4 only "
                  "(2 dwords does not exist on any supported arch; "
                  "3 dwords only on CDNA4 and unused in FMHA pipeline)");

    // Inline asm: only the global address is an explicit operand. The LDS
    // destination is implicit via M0 (see contract above). `"=r"(smem)` is a
    // SSA scheduling anchor only — `smem` is NOT written by this asm; the
    // load goes to LDS at `M0[17:2]*4 + offset:0 + ThreadID*4`.
#define CK_TILE_GLOBAL_LOAD_LDS_INSTR(instr)                                 \
    if constexpr(pre_nop)                                                    \
        asm volatile("s_nop 4\n" instr " %1, off offset:0"                   \
                     : "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \
                     : "v"(global_addr)                                      \
                     : "memory" /*prevents reorder across m0_{set,inc}*/);   \
    else                                                                     \
        asm volatile(instr " %1, off offset:0"                               \
                     : "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \
                     : "v"(global_addr)                                      \
                     : "memory" /*prevents reorder across m0_{set,inc}*/);

    if constexpr(num_dwords == 1)
    {
        CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dword");
    }
    else if constexpr(num_dwords == 4)
    {
        CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dwordx4");
    }
#undef CK_TILE_GLOBAL_LOAD_LDS_INSTR
}

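// Illustrative caller pattern for the M0 contract above (an editor's sketch,
// not code from this commit; `p_global`, `elems_per_issue` and `num_issues`
// are hypothetical names, while m0_set_with_memory, m0_inc_with_memory,
// amd_wave_read_first_lane and async_buffer_load_fence are the helpers named
// in the contract):
//
//   m0_set_with_memory(amd_wave_read_first_lane(lds_byte_offset)); // step 1
//   for(index_t i = 0; i < num_issues; ++i)
//   {
//       // one issue; the LDS destination is implicit via M0
//       async_global_load_lds_dwordxn<4>(smem, p_global + i * elems_per_issue);
//       m0_inc_with_memory(size_per_issue); // step 2: multiple of 4 bytes
//   }
//   async_buffer_load_fence(0); // wait for the outstanding LDS writes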
template <index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
|
||||
CK_TILE_DEVICE thread_buffer<int8_t, N>
|
||||
|
||||
@@ -45,9 +45,29 @@ template <typename BottomTensorView_,
|
||||
typename StaticValidArray_,
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1,
|
||||
typename YsGatherDims = sequence<0>>
|
||||
typename YsGatherDims = sequence<0>,
|
||||
bool kUseGlobalLoad_ = false>
|
||||
struct tile_scatter_gather
|
||||
{
|
||||
static constexpr bool kUseGlobalLoad = kUseGlobalLoad_;
|
||||
|
||||
#if !defined(__gfx94__) && !defined(__gfx950__)
    // global_load_lds instruction is only available on CDNA3+ (gfx940/gfx950).
    // On other architectures, kUseGlobalLoad must be false.
    static_assert(!kUseGlobalLoad_,
                  "kUseGlobalLoad requires global_load_lds (CDNA3+: gfx940/gfx950). "
                  "This kernel should not be instantiated on this architecture.");
#endif

    // Empty placeholder used by the SRD instantiation so physical_pages_ and
    // page_stride_elements_ occupy zero bytes there (combined with
    // [[no_unique_address]] on the member declarations). Access sites are all
    // inside `if constexpr(kUseGlobalLoad_)` arms, which compile out in SRD
    // mode, so no caller needs to change.
    struct gl_field_empty_t
    {
    };

using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
@@ -233,15 +253,22 @@ struct tile_scatter_gather
|
||||
const BottomTensorIndex& window_origin,
|
||||
const TileDstr& tile_distribution,
|
||||
const PageIdxArray& page_idx,
|
||||
const ValidArray& valids)
|
||||
const ValidArray& valids,
|
||||
index_t page_stride_elements = 0)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin},
|
||||
tile_dstr_{tile_distribution},
|
||||
page_idx_{page_idx},
|
||||
physical_pages_{},
|
||||
page_stride_elements_{},
|
||||
valids_{valids},
|
||||
pre_computed_coords_{}
|
||||
{
|
||||
if constexpr(kUseGlobalLoad_)
|
||||
{
|
||||
page_stride_elements_ = page_stride_elements;
|
||||
}
|
||||
#if 0 // debug
|
||||
// TODO: this uses more registers for FA, but fewer registers for GEMM;
// needs investigation
@@ -357,6 +384,34 @@ struct tile_scatter_gather
bottom_tensor_view_.buf_.p_data_ = data;
}

// Override buffer size (input in RAW elements, NOT pre-divided by PackedSize) for
// SRD num_records control. Use to set max range when SRD is rebased per-tile
// (page_size >= kN0 path): each rebased SRD only needs to cover one page; without
// this the SRD claims validity for memory beyond the allocated buffer, which can
// fault on gfx950 page-table validation.
//
// Matches buffer_view ctor convention (buffer_view.hpp:245): input is raw element
// count and is divided by PackedSize before being stored. For PackedSize=1
// (fp16/bf16/fp8) the division is a no-op; for PackedSize=2 (FP4 / packed int4)
// skipping it would over-report num_records by 2x and silently mask OOB on SRD
// reads. batch_prefill currently does not exercise the packed-type path, but this
// setter is generic infrastructure (lives in tile_scatter_gather.hpp) so it must
// honor the same invariant the ctor enforces.
CK_TILE_DEVICE constexpr void set_bottom_tensor_view_buffer_size(index_t size)
{
// Hint the optimizer that size is positive without inserting a runtime
// branch. Using <cassert> assert() here corrupted gfx950 batch_prefill
// output: the __assert_fail handler's SGPR pressure forced the K-SRD
// register window to be reused as scratch and scattered the SRD writes
// across two conditional branches, which gfx950's packed
// buffer_load_dwordx4 issue window doesn't tolerate (gfx942 absorbs it
// via per-tile single-dword loads). __builtin_assume is hint-only —
// no branch, no scratch SGPRs, no codegen impact.
__builtin_assume(size > 0);
using BufType = remove_cvref_t<decltype(bottom_tensor_view_.buf_)>;
bottom_tensor_view_.buf_.buffer_size_ = size / BufType::PackedSize;
}

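// Worked example (illustrative numbers, not from this change): rebasing the
// K SRD to a single page of 16 tokens x 128 hdim fp16 elements passes
// size = 2048 raw elements; PackedSize == 1, so buffer_size_ = 2048. For a
// packed-int4 view with PackedSize == 2 the same call stores 1024, matching
// what the buffer_view constructor would have computed for that page.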
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
@@ -458,7 +513,21 @@ struct tile_scatter_gather

// read from bottom tensor
const vector_t vec_value = [&]() {
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
if constexpr(kUseGlobalLoad_)
{
// Global load mode: 64-bit typed pointer arithmetic
const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_;
const auto physical_page = physical_pages_[idx_gather];
const auto coord_offset = bottom_tensor_thread_coord.get_offset();
const long_index_t total_offset =
static_cast<long_index_t>(physical_page) * page_stride_elements_ +
coord_offset + page_offset;
const auto* addr = base_ptr + total_offset;
vector_t v;
__builtin_memcpy(&v, addr, sizeof(vector_t));
return v;
}
else if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
@@ -680,7 +749,23 @@ struct tile_scatter_gather
const auto page_offset = page_idx_[idx_gather];

// read from bottom tensor
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
if constexpr(kUseGlobalLoad_)
{
// Global load mode: global_load_lds with 64-bit address
constexpr index_t vector_size =
sizeof(vector_t) / sizeof(uint32_t); // dwords per vector
const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_;
const auto physical_page = physical_pages_[idx_gather];
const auto coord_offset = bottom_tensor_thread_coord.get_offset();
const long_index_t total_offset =
static_cast<long_index_t>(physical_page) * page_stride_elements_ +
coord_offset + page_offset;
const auto* addr = base_ptr + total_offset;
// global_load_lds takes a byte address; addr (const DataType*)
// converts implicitly to const void*, no explicit cast needed.
async_global_load_lds_dwordxn<vector_size>(smem, addr, pre_nop_);
}
else if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
@@ -1046,6 +1131,13 @@ struct tile_scatter_gather

CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }

CK_TILE_DEVICE void update_physical_pages(const PageIdxArray& pages)
{
static_assert(kUseGlobalLoad_,
"global-load mode only; physical_pages_ is unused in SRD mode.");
physical_pages_ = pages;
}

CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
{
if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
@@ -1139,7 +1231,29 @@ struct tile_scatter_gather
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;

// Scatter/gather offsets for each element, set by update_page_idx().
// SRD mode (kUseGlobalLoad=false): buffer_load(SRD, page_idx_[i] + coord).
// page_idx_[i] = within-page offset when kPageBlockSize >= kN0 (SRD rebased to page base)
// page_idx_[i] = page_base + within-page offset when kPageBlockSize < kN0 (full voffset)
// Global load mode (kUseGlobalLoad=true): page_idx_[i] = within-page offset only.
// Full address = base + physical_pages_[i] * page_stride_elements_ + page_idx_[i] + coord
PageIdxArray page_idx_;

// Physical page indices for global load mode (kUseGlobalLoad=true only).
// Maps each gather element to its physical page in a paged memory pool.
// Updated via update_physical_pages() before each load call.
// SRD mode: collapsed to gl_field_empty_t so the storage disappears.
[[no_unique_address]] std::conditional_t<kUseGlobalLoad_, PageIdxArray, gl_field_empty_t>
physical_pages_;

// Page stride in elements for global load mode (kUseGlobalLoad=true only).
// physical_pages_[i] * page_stride_elements_ gives the page base offset in elements.
// Set at construction time via the make_tile_scatter_gather overload that
// takes bool_constant<kUseGlobalLoad>; immutable thereafter.
// SRD mode: collapsed to gl_field_empty_t so the storage disappears.
[[no_unique_address]] std::conditional_t<kUseGlobalLoad_, index_t, gl_field_empty_t>
page_stride_elements_;

ValidArray valids_;

// this contains:
@@ -1178,7 +1292,8 @@ template <typename TensorView_,
typename StaticPageIndexArray_,
index_t HsGatherDim,
index_t NumCoord,
index_t... YsGatherDims>
index_t... YsGatherDims,
bool UseGlobalLoad = false>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
@@ -1187,7 +1302,9 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
const StaticPageIndexArray_& page_idx,
number<HsGatherDim>,
number<NumCoord>,
sequence<YsGatherDims...>)
sequence<YsGatherDims...>,
bool_constant<UseGlobalLoad> = {},
index_t page_stride_elements = 0)
{
return tile_scatter_gather<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
@@ -1196,11 +1313,17 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
std::nullptr_t,
HsGatherDim,
NumCoord,
sequence<YsGatherDims...>>{
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
sequence<YsGatherDims...>,
UseGlobalLoad>{tensor_view,
window_lengths,
origin,
tile_distribution,
page_idx,
nullptr,
page_stride_elements};
}

// Legacy overload (compatible with original API)
// Legacy overload (compatible with original API, kUseGlobalLoad=false)
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
@@ -1227,6 +1350,42 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
}

// Overload with kUseGlobalLoad (simple, used by K cache).
// page_stride_elements is forwarded to the constructor; required (non-zero)
// when UseGlobalLoad=true so that physical_pages_[i] * page_stride_elements_
// produces a valid address. Defaulting to 0 keeps SRD-mode call sites unchanged
// (page_stride_elements_ is unread in SRD mode).
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
bool UseGlobalLoad>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx,
bool_constant<UseGlobalLoad>,
index_t page_stride_elements = 0)
{
return tile_scatter_gather<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<StaticPageIndexArray_>,
std::nullptr_t,
0,
1,
sequence<0>,
UseGlobalLoad>{tensor_view,
window_lengths,
origin,
tile_distribution,
page_idx,
nullptr,
page_stride_elements};
}
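// Illustrative call sketch of the kUseGlobalLoad overload above (hypothetical
// window/argument names, not from this change):
//
//   auto k_window = make_tile_scatter_gather(k_view,
//                                            k_lengths,
//                                            k_origin,
//                                            k_dist,
//                                            k_page_offsets,          // within-page offsets
//                                            bool_constant<true>{},   // UseGlobalLoad
//                                            page_stride_elements);
//   k_window.update_physical_pages(k_physical_pages);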

template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,

@@ -12,6 +12,20 @@

namespace ck_tile {

// `always_false_v<T...>` — a value-template that is always `false` but whose
// evaluation is deferred until template instantiation. The canonical use is
// inside the `else` arm of an `if constexpr` chain or under an arch-gated
// `#if` to fire a `static_assert` ONLY when the offending instantiation is
// actually requested, e.g.:
//
// if constexpr (...) { ... }
// else { static_assert(always_false_v<T>, "unsupported T"); }
//
// A bare `static_assert(false, ...)` would fire at template-definition
// parse time on conforming compilers, breaking the whole TU.
template <typename...>
inline constexpr bool always_false_v = false;

// remove_cvref_t
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;

@@ -6,6 +6,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include <iostream>

namespace ck_tile {


@@ -3,6 +3,7 @@
#pragma once

#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
@@ -55,6 +56,7 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

// KV cache load addressing mode selector for batch_prefill / paged-attention pipelines.
// - BUFFER_LOAD: SGPR-based SRD via buffer_load_* (default; 32-bit byte addressing, <2GB pool)
// - GLOBAL_LOAD_LDS: direct global_load_lds_* (64-bit addressing, required for >2GB KV cache)
enum class BlockAttentionKVCacheLoadModeEnum
{
BUFFER_LOAD = 0,
GLOBAL_LOAD_LDS = 1,
};

} // namespace ck_tile
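// Illustrative trait-level selection sketch (hypothetical trait name; the real
// wiring lives in the Problem/Traits definitions elsewhere in this change):
//
//   template <bool kLargeKVPool>
//   struct my_prefill_traits
//   {
//       static constexpr auto kKVLoadMode =
//           kLargeKVPool ? BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS
//                        : BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD;
//   };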
@@ -32,6 +32,83 @@

namespace ck_tile {

namespace detail {

// A helper struct for detecting n0loop
template <typename T, typename = void>
struct has_n0loop_flag : std::false_type
{
};

template <typename T>
struct has_n0loop_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kUseN0Loop), bool> && T::kUseN0Loop>>
: std::true_type
{
};

template <typename T>
static inline constexpr bool is_n0loop_pipeline_v = has_n0loop_flag<T>::value;

// A helper struct for detecting ignore_fast_exp2 flag
template <typename T, typename = void>
struct has_ignore_fast_exp2_flag : std::false_type
{
};

// kIgnoreFastExp2 is defined by pipelines that explicitly choose not to use FAST_EXP2;
// by detecting kIgnoreFastExp2 on the pipeline, the kernel's MakeKargsImpl() interface
// can avoid passing an incorrect scale_s parameter to the kernel layer
template <typename T>
struct has_ignore_fast_exp2_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kIgnoreFastExp2), bool> &&
T::kIgnoreFastExp2>> : std::true_type
{
};

template <typename T>
static inline constexpr bool ignore_fast_exp2_v = has_ignore_fast_exp2_flag<T>::value;

// A helper struct for detecting naive_hdim_load. naive_hdim_load means loading tiles of
// hdim96/hdim160/hdim192 without padding the tensor_view/tile_window to hdim128/hdim256;
// naive_hdim_load is currently supported by the qr_ks_vs_whole_k_prefetch_pipeline
template <typename T, typename = void>
struct has_naive_hdim_load_flag : std::false_type
{
};

template <typename T>
struct has_naive_hdim_load_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kIsNaiveHDimLoad), bool> &&
T::kIsNaiveHDimLoad>> : std::true_type
{
};

template <typename T>
static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag<T>::value;

// A helper struct for detecting kUseTrLoad
template <typename T, typename = void>
struct has_use_trload_flag : std::false_type
{
};

template <typename T>
struct has_use_trload_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kUseTrLoad), bool> && T::kUseTrLoad>>
: std::true_type
{
};

template <typename T>
static inline constexpr bool is_using_trload_v = has_use_trload_flag<T>::value;

} // namespace detail

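// Illustrative sketch of how these detectors behave (hypothetical pipeline
// types, not part of this change): a pipeline that defines the flag opts in;
// anything else silently evaluates to false.
//
//   struct PipeWithN0Loop    { static constexpr bool kUseN0Loop = true; };
//   struct PipeWithoutN0Loop { };
//   static_assert(detail::is_n0loop_pipeline_v<PipeWithN0Loop>);
//   static_assert(!detail::is_n0loop_pipeline_v<PipeWithoutN0Loop>);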
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernel
{
@@ -77,13 +154,14 @@ struct FmhaFwdKernel
static constexpr bool kHasMask = FmhaMask::IsMasking;

static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
static constexpr bool kUseTrLoad = detail::is_using_trload_v<FmhaPipeline>;

static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
#if defined(__gfx950__)
static constexpr bool kIsAvailable = true;
#else
static constexpr bool kIsAvailable = !kUseTrLoad;
#endif

static constexpr std::string_view kPipelineName = FmhaPipeline::name;

template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a template
@@ -444,7 +522,9 @@ struct FmhaFwdKernel
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
detail::ignore_fast_exp2_v<FmhaPipeline>
? scale_s
: static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
scale_s,
#endif
@@ -897,7 +977,9 @@ struct FmhaFwdKernel
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
detail::ignore_fast_exp2_v<FmhaPipeline>
? scale_s
: static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
scale_s,
#endif
@@ -1039,6 +1121,7 @@ struct FmhaFwdKernel
const void* seqlen_k_ptr,
const void* block_scale_seqstart_q_ptr,
const void* block_scale_seqstart_k_ptr,
const void* seqstart_v_scale_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -1097,6 +1180,7 @@ struct FmhaFwdKernel
seqlen_k_ptr,
block_scale_seqstart_q_ptr,
block_scale_seqstart_k_ptr,
seqstart_v_scale_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -1158,6 +1242,7 @@ struct FmhaFwdKernel
const void* seqlen_k_ptr,
const void* block_scale_seqstart_q_ptr,
const void* block_scale_seqstart_k_ptr,
const void* seqstart_v_scale_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -1216,6 +1301,7 @@ struct FmhaFwdKernel
seqlen_k_ptr,
block_scale_seqstart_q_ptr,
block_scale_seqstart_k_ptr,
seqstart_v_scale_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -1602,6 +1688,10 @@ struct FmhaFwdKernel
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;

constexpr index_t kQKHeaddimToUse = detail::is_naive_hdim_load_v<FmhaPipeline>
? FmhaPipeline::kQKHeaddim
: FmhaPipeline::kSubQKHeaddim;

// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1612,10 +1702,10 @@ struct FmhaFwdKernel
number<1>{});
if constexpr(FmhaPipeline::kQLoadOnce)
{
return pad_tensor_view(q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
@@ -1634,10 +1724,21 @@ struct FmhaFwdKernel
number<1>{});

constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});

if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -1649,18 +1750,29 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentV>{},
number<1>{});

const auto v_dram_transposed = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
if constexpr(!kUseTrLoad)
{
const auto v_dram_transposed = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));

constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;

return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
else
{
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadHeadDimV>{});
};
}
else
{
@@ -1683,17 +1795,28 @@ struct FmhaFwdKernel
q_dram,
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kSubQKHeaddim>{});
return make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(),
{i_m0, 0});

auto k_dram_window = make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{0, 0});
auto k_dram_window = [&]() {
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
{
return make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
{0, 0});
}
else
{
return make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{0, 0});
}
}();

auto v_dram_window = make_tile_window(
v_dram,
@@ -1843,7 +1966,10 @@ struct FmhaFwdKernel
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
if constexpr(!detail::ignore_fast_exp2_v<FmhaPipeline>)
{
slope *= ck_tile::log2e_v<>;
}
#endif
if constexpr(kHasMask)
{
@@ -2826,7 +2952,10 @@ struct FmhaFwdKernel
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
if constexpr(!detail::ignore_fast_exp2_v<FmhaPipeline>)
{
slope *= ck_tile::log2e_v<>;
}
#endif
if constexpr(kHasMask)
{

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
@@ -134,7 +135,8 @@ template <typename IndexArrayType,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
bool kIsKcache,
index_t kN0,
index_t kVectorSize>
index_t kVectorSize,
bool kUseGlobalLoad_ = false>
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages,
const index_t& stride_token,
const index_t& stride_page_block,
@@ -156,81 +158,65 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica
const index_t& thread_coord_start = coord_vec[kCoordAxis];
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;

if constexpr(kIsKcache)
{
// K cache: per-token lookup
// Each token may be on a different page, so we use physical_pages[k0] for each.
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
// Addressing strategy — four cases controlled by (kPageBlockSize vs kN0, kUseGlobalLoad_):
//
// Case 1: kPageBlockSize >= kN0
// SRD is rebased per-tile to the page base (rebase_{k,v}_window in caller).
// Page base is absorbed into the SRD's 48-bit base pointer (SGPR-resident).
// This function writes within-page offset only.
//
// Case 2: kPageBlockSize < kN0 && kUseGlobalLoad_
// SRD cannot be rebased (multi-page wave). Loads use global_load_lds_*; the full
// 64-bit address is computed by tile_scatter_gather::load() in
// include/ck_tile/core/tensor/tile_scatter_gather.hpp from physical_pages_ +
// page_stride_elements_. This function writes within-page offset only.
//
// Case 3: kPageBlockSize < kN0 && !kUseGlobalLoad_ (kNeedFullOffset == true)
// SRD base is the entire KV buffer; the only place to encode page identity
// is the voffset itself. This function writes the FULL offset:
// page * stride_page_block + within_page
// Limited to <2GB total KV bytes by 32-bit voffset hardware width.
//
// Case 4: kPageBlockSize >= kN0 && kUseGlobalLoad_
// Not emitted by codegen. Backstop static_assert in
// BlockFmhaBatchPrefillPipelineQRKSVSAsync.
constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseGlobalLoad_;

if constexpr(kPageBlockSize >= kN0)
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;

// Within-page offset (layout-dependent for V cache with VECTORIZED_LAYOUT)
const index_t within_page = [&]() {
if constexpr(!kIsKcache && kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// SRD rebasing mode: within-page offset only.
// The full page base is handled by rebasing the SRD pointer.
kv_offset_vec[k0] = token_idx_in_page * stride_token;
return (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
}
else
{
// Full global offset (original code path for ps1, ps16, etc.)
const index_t physical_page = physical_pages[k0];
kv_offset_vec[k0] =
physical_page * stride_page_block + token_idx_in_page * stride_token;
return token_idx_in_page * stride_token;
}
});
}
else // V cache
{
// V cache: use physical_pages[k0] for each token
// physical_pages was already populated correctly by load_physical_pages(), handling:
// - page_size=1: page_idx maps token_idx -> physical_page directly
// - V tile crosses pages: per-token page lookup
// - V tile in single page: lane0 lookup with broadcast to all lanes
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
}();

if constexpr(kPageBlockSize >= kN0)
{
// SRD rebasing mode: within-page offset only.
// The full page base is handled by rebasing the SRD pointer.
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const index_t token_offset =
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
kv_offset_vec[k0] = token_offset;
}
else
{
kv_offset_vec[k0] = token_idx_in_page * stride_token;
}
}
else
{
// Full global offset (original code path for ps1, ps16, etc.)
const index_t physical_page = physical_pages[k0];
const long_index_t page_base_offset =
static_cast<long_index_t>(physical_page) * stride_page_block;

if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const index_t token_offset =
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
kv_offset_vec[k0] = page_base_offset + token_offset;
}
else
{
kv_offset_vec[k0] = page_base_offset + token_idx_in_page * stride_token;
}
}
});
}
// SRD + page_size < kN0: add page base to form complete voffset for buffer_load.
//
// 32-bit by hardware: SRD buffer_load voffset is fundamentally 32-bit (CDNA3 MUBUF
// microcode format), so this branch is only reachable when total KV bytes fit in
// INT32_MAX. The kUseGlobalLoad_ template path handles the >2GB case via 64-bit
// global_load_lds_*; widening kv_offset_vec here would not lift the 2GB ceiling
// because the hardware truncates voffset regardless.
if constexpr(kNeedFullOffset)
{
kv_offset_vec[k0] = physical_pages[k0] * stride_page_block + within_page;
}
else
{
kv_offset_vec[k0] = within_page;
}
});
}

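// Worked example (illustrative numbers, not from this change): with
// kPageBlockSize = 16 (kLog2PageSize = 4), a hypothetical stride_token = 576
// and stride_page_block = 16 * 576 = 9216 elements, global token 37 has
// token_idx_in_page = 37 & 15 = 5, so for a non-vectorized layout:
//   Cases 1/2 (within-page only):   kv_offset_vec[k0] = 5 * 576          = 2880
//   Case 3   (full 32-bit voffset): kv_offset_vec[k0] = p * 9216 + 2880, where
//                                   p = physical_pages[k0]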
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
@@ -270,10 +256,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
static constexpr index_t kVectorSize = Problem::kVectorSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
// Single load-mode selector for the whole pipeline. GLOBAL_LOAD_LDS routes K/V
// tiles through global_load_lds_* (handles >2GB KV cache); BUFFER_LOAD uses SRD
// buffer_load_*. The enum is named at the trait/Problem level; internally we
// derive a bool helper to keep `if constexpr` sites narrow. Codegen only emits
// GLOBAL_LOAD_LDS arms when page_size < kN0; the static_assert is a backstop.
static constexpr auto kKVLoadMode = Problem::kKVLoadMode;
static constexpr bool kUseGlobalLoad =
(kKVLoadMode == BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS);
static_assert(!kUseGlobalLoad || (kPageBlockSize < kN0),
"GLOBAL_LOAD_LDS load mode is only valid when kPageBlockSize < kN0; "
"codegen should not emit this instantiation otherwise.");
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};

static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
@@ -626,19 +623,26 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
true,
kN0,
kVectorSize>(
kVectorSize,
kUseGlobalLoad>(
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);

auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
k_dist,
k_offsets); // K DRAM tile window for
k_offsets,
bool_constant<kUseGlobalLoad>{},
page_stride_k);
if constexpr(kUseGlobalLoad)
{
k_dram_window.update_physical_pages(k_physical_pages);
}
k_dram_window.init_raw();

// SRD rebasing: move the buffer descriptor base pointer to each page's start address
// using 48-bit pointer arithmetic, so voffset only needs the small within-page offset.
// Only applies when kPageBlockSize >= kN0 (all threads in a wave access the same page).
// SRD rebasing for K: only for page_size >= kN0 (all threads on same page).
// For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
// addressing.
auto rebase_k_window = [&](auto& window, index_t physical_page) {
if constexpr(kPageBlockSize >= kN0)
{
@@ -649,24 +653,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const auto* page_ptr =
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_k;
window.set_bottom_tensor_view_data_ptr(page_ptr);
// Limit SRD num_records to one page worth of elements.
// Without this, the SRD claims validity for [page_ptr, page_ptr +
// full_buffer_size), which extends far beyond the allocated buffer when rebased to
// high pages. On gfx950, the hardware may validate the full SRD range against page
// table permissions, causing faults on freed/protected memory beyond the buffer.
window.set_bottom_tensor_view_buffer_size(page_stride_k);
window.init_raw();
}
};

// SRD rebasing for V: only for page_size >= kN0 (all threads on same page).
// For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
// addressing.
auto rebase_v_window = [&](auto& window, index_t physical_page) {
if constexpr(kPageBlockSize >= kN0)
{
// readfirstlane: make physical_page provably wave-uniform so the
// resulting SRD lands in SGPRs (required by buffer load instructions).
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
const auto* base_ptr =
v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_;
const auto* page_ptr =
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_v;
window.set_bottom_tensor_view_data_ptr(page_ptr);
window.set_bottom_tensor_view_buffer_size(page_stride_v);
window.init_raw();
}
};

// Initial K SRD rebase
// Initial K SRD rebase (no-op for page_size < kN0, uses flat loads instead)
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);

constexpr auto k_oob_ck = bool_constant<true>{};
@@ -874,12 +890,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
false,
kN0,
kVectorSize>(v_physical_pages_k2,
stride_v,
page_stride_v,
v_coord,
v_offsets_k2,
current_seq_k);
kVectorSize,
kUseGlobalLoad>(v_physical_pages_k2,
stride_v,
page_stride_v,
v_coord,
v_offsets_k2,
current_seq_k);

static_for<0, V_KIterInner, 1>{}([&](auto k1) {
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
@@ -899,9 +916,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
kVectorSize,
kUseGlobalLoad>(
v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
}

// v_offsets semantics — see the four-case addressing-strategy block above
// kNeedFullOffset in kv_offset_array_transform. Three cases reach this lambda:
// Case 1 (kPageBlockSize >= kN0): within-page offset; page base in SRD.
// Case 2 (page_size < kN0, kUseGlobalLoad): within-page offset; page base computed
// by tile_scatter_gather::load() from
// physical_pages_.
// Case 3 (page_size < kN0, !kUseGlobalLoad, == kNeedFullOffset):
// FULL offset (page * stride + within),
// carried in the 32-bit voffset (<2GB cap).
};

// Prefetch V physical pages early to hide buffer load latency
@@ -915,11 +943,32 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
v_offsets,
number<1>{}, // HsGatherDim
number<1>{}, // NumCoord
VPageIndexYDims);
VPageIndexYDims,
bool_constant<kUseGlobalLoad>{},
page_stride_v);
if constexpr(kUseGlobalLoad)
{
v_dram_window.update_physical_pages(v_physical_pages);
}

// Initial V SRD rebase
// Initial V SRD rebase. Single source of truth: rebase_v_window's own
// `if constexpr(kPageBlockSize >= kN0)` makes this a no-op for case 2/3.
// Do not re-add an outer guard here — it would duplicate the inner check
// and drift if the lambda's gating condition ever changes.
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);

// Save the *current* tile's V physical pages into v_dram_window before
// prefetch_v_physical_pages overwrites the v_physical_pages buffer with the
// *next* tile's pages. Case-2 only (kUseGlobalLoad); case-1/3 don't read
// physical_pages_ from the window. Encapsulating the save+prefetch pair
// here makes the ordering invariant unmissable when a fourth prefetch site
// is added later.
auto save_and_prefetch_v_pages = [&](auto k_loop_start) {
if constexpr(kUseGlobalLoad)
v_dram_window.update_physical_pages(v_physical_pages);
prefetch_v_physical_pages(k_loop_start);
};

// prefetch K tile
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
@@ -972,7 +1021,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
}

// Prefetch V physical pages early - overlaps with GEMM0 computation
prefetch_v_physical_pages(number<kK1>{});
save_and_prefetch_v_pages(number<kK1>{});

// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
@@ -1166,7 +1215,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// Prefetch V physical pages early - overlaps with softmax computation
if constexpr(k1_loops > 1)
{
prefetch_v_physical_pages(number<2 * kK1>{});
save_and_prefetch_v_pages(number<2 * kK1>{});
}

auto m_local = block_tile_reduce<SMPLComputeDataType>(
@@ -1220,8 +1269,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
v_dram_window,
{0,
kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
update_v_offsets(number<2 * kK1>{});
v_dram_window.update_page_idx(v_offsets);
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
@@ -1390,8 +1438,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
// Update V offsets using previously prefetched physical pages
update_v_offsets(number<(2 + i_k1.value) * kK1>{});
v_dram_window.update_page_idx(v_offsets);
@@ -1401,7 +1448,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// Prefetch V physical pages for NEXT iteration - overlaps with GEMM1
if constexpr(i_k1 + 1 < k1_loops - 1)
{
prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{});
save_and_prefetch_v_pages(number<(2 + i_k1.value + 1) * kK1>{});
}

block_sync_lds();
@@ -1481,9 +1528,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
true,
kN0,
kVectorSize>(
kVectorSize,
kUseGlobalLoad>(
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
k_dram_window.update_page_idx(k_offsets);
if constexpr(kUseGlobalLoad)
k_dram_window.update_physical_pages(k_physical_pages);
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);

// After sink→window transition (i_total_loops == num_sink_loop), V window

@@ -9,6 +9,52 @@

namespace ck_tile {

namespace detail {

template <typename DataType, index_t ElemPerThread>
CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize()
{
if constexpr(std::is_same_v<DataType, half_t> || std::is_same_v<DataType, bf16_t>)
{
// ToDo: need support in ck_tile for using buffer_load_dwordx3
// if constexpr(ElemPerThread % 6 == 0)
// return 6;
if constexpr(ElemPerThread % 8 == 0)
return 8;
else if constexpr(ElemPerThread % 4 == 0)
return 4;
else if constexpr(ElemPerThread % 2 == 0)
return 2;
return 1;
}
else if constexpr(std::is_same_v<DataType, float>)
{
// ToDo: need support in ck_tile for using buffer_load_dwordx3
// if constexpr(ElemPerThread % 3 == 0)
// return 3;
if constexpr(ElemPerThread % 4 == 0)
return 4;
else if constexpr(ElemPerThread % 2 == 0)
return 2;
return 1;
}
else
return 1;
};

template <typename DataType,
index_t kThreadBlockSize,
index_t kHigherDimSize,
index_t kLowerDimSize>
CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize()
{
constexpr index_t ElemPerThread = (kHigherDimSize * kLowerDimSize) / kThreadBlockSize;

return GetMaxVectorSize<DataType, ElemPerThread>();
}

} // namespace detail

template <typename QDataType_,
typename KDataType_,
typename VDataType_,
@@ -117,6 +163,12 @@ struct BlockFmhaBatchPrefillPipelineProblem
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
"kPageBlockSize must be power of two");

// KV cache load addressing mode. GLOBAL_LOAD_LDS handles >2GB pools via
// 64-bit addressing; BUFFER_LOAD (default) uses SRD buffer_load for the
// <2GB fast path. The 2GB bound = INT32_MAX byte offset, matching CK's
// existing TwoGB convention.
static constexpr auto kKVLoadMode = Traits_::kKVLoadMode;

static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
static constexpr auto kKVLookupTable = Traits_::kKVLookupTable;

File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,861 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
||||
struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using CompDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true;
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
static_assert(sizeof(KDataType) == sizeof(VDataType) &&
|
||||
alignof(KDataType) == alignof(VDataType),
|
||||
"K and V share the same LDS region; their element types must have identical "
|
||||
"size and alignment.");
|
||||
|
||||
static constexpr bool kUseN0Loop = true;
|
||||
static constexpr bool kIgnoreFastExp2 = true;
|
||||
static constexpr bool kIsNaiveHDimLoad = true;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kN0Sub =
|
||||
BlockFmhaShape::kK0; // subdivision of kN0 used in N0-loop, same value as kK0
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
|
||||
static_assert(Problem::kUseTrLoad == true, "Check failed!");
|
||||
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
|
||||
// since this pipeline is only used by the inference path of xformers, the Dropout function is
|
||||
// not well tested with the pipeline, so here we have Dropout disabled
|
||||
static_assert(kHasDropout == false, "Dropout is not supported by this pipeline at present!");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim == 32)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim == 64)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128)
|
||||
{
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim == 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
};
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "qr_async_whole_k_prefetch_trload";
|
||||
|
||||
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& /* unused */,
|
||||
const AttentionVariantParams& /* unused */,
|
||||
const BlockIndices& /* unused */,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
// xformers path does not require the pipeline to output random values for host
|
||||
// verification, since a separate kernel is used to generate random values
|
||||
ignore = randval_dram_block_window_tmp;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0Sub == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
constexpr index_t n0_loops = kN0 / kN0Sub;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
// usually kN0 is 128, kN0Sub/kK1 is 32/16
|
||||
static_assert(n0_loops >= 2, "n0_loops >= 2 required to use this pipeline");
|
||||
static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline");
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
|
||||
static_assert(n0_loops >= NumPrefetchV, "Check failed!");
|
||||
static_assert(k1_loops >= NumPrefetchV, "Check failed!");
|
||||
|
||||
constexpr bool kPreloadWholeNextIterationK =
|
||||
Policy::template IsPreloadWholeNextIterationK<Problem>();
|
||||
|
||||
// This path prefetches two k_tiles for next iteration, so it has the opportunity to
|
||||
// prefetch two v_tiles during Gemm0
|
||||
if constexpr(!kPreloadWholeNextIterationK)
|
||||
{
|
||||
static_assert(NumPrefetchV >= 2);
|
||||
};
|
||||
|
||||
        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();

        // SaccBlockTile size is [kM0, kN0Sub]
        // PcompBlockTile size is [kM0, kN0]
        using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
        using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
        using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));

        SaccBlockTileType sacc_tile;
        PcompBlockTileType pcomp_tile;

        // reduction functions for softmax
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

        using MLBlockTileType = decltype(block_tile_reduce<CompDataType>(
            PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0}));

        auto m = MLBlockTileType{};
        auto l = MLBlockTileType{};

        using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
        OaccBlockTileType o_acc;

        auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                                              make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                              q_dram_block_window_tmp.get_window_origin(),
                                              Policy::template MakeQRegTileDistribution<Problem>());

        const auto q_origin = q_dram_window.get_window_origin();
        const auto [seqlen_k_start, seqlen_k_end] =
            mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

        // nothing to compute for this row tile: every key column is masked out
        if(seqlen_k_end <= seqlen_k_start)
        {
            clear_tile(o_acc);
            o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
            return o_acc;
        };

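The early return above handles row tiles whose whole key range is masked. A simplified scalar sketch of the kind of range computation GetTileRangeAlongX performs for a plain causal mask follows (assumed semantics, for illustration only; the real FmhaMask also handles padding and windowed masks):

#include <algorithm>
#include <cassert>

// Illustrative only: for a row tile starting at q_start with kM0 rows under a causal mask,
// only key columns below q_start + kM0 can contribute; rounding up to the kN0 tile width
// gives the key-tile loop bounds used by the main loop.
struct TileRange { int start, end; };

TileRange causal_tile_range(int q_start, int kM0, int kN0, int seqlen_k)
{
    const int last_visible_col = std::min(q_start + kM0, seqlen_k);   // exclusive
    const int end = ((last_visible_col + kN0 - 1) / kN0) * kN0;       // round up to a tile
    const int k_tiles_end = ((seqlen_k + kN0 - 1) / kN0) * kN0;
    return {0, std::min(end, k_tiles_end)};
}

int main()
{
    const TileRange r = causal_tile_range(/*q_start=*/256, /*kM0=*/64, /*kN0=*/128, /*seqlen_k=*/1024);
    assert(r.start == 0 && r.end == 384); // rows up to 320 visible -> 3 key tiles of 128
    return 0;
}
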
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
auto k_tiles = [&]() {
|
||||
if constexpr(kPreloadWholeNextIterationK)
|
||||
return statically_indexed_array<k_tile_type, n0_loops>{};
|
||||
else
|
||||
return statically_indexed_array<k_tile_type, 2>{};
|
||||
}();
|
||||
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
|
||||
if constexpr(!kPreloadWholeNextIterationK)
|
||||
{
|
||||
k_tiles[I1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// provide partition_index for LDS tile window with so that warp_id is in vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
// K tile in LDS
|
||||
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using v_lds_window_type =
|
||||
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kK1, kN1>{}));
|
||||
|
||||
statically_indexed_array<v_lds_window_type, NumKVLdsBuffers> v_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
v_lds_windows[i_buf] = get_slice_tile(
|
||||
v_lds_window, sequence<i_buf * kK1, 0>{}, sequence<(i_buf + 1) * kK1, kN1>{});
|
||||
});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kK1>{}, number<kN1>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
        const auto f_exp = [&](CompDataType x) {
            if constexpr(std::is_same_v<CompDataType, float>)
            {
                // faster device intrinsic for fp32 accumulation
                return __expf(x);
            }
            else
            {
                return exp(x);
            }
        };

const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start},
|
||||
Policy::template MakeBiasDramTileDistribution<Problem>());
|
||||
|
||||
        // Assume no random values need to be saved; this is the case when the pipeline is called
        // from xformers, since a separate kernel generates the random values.
        auto null_randval_window = [&]() {
            if constexpr(kHasDropout)
            {
                // still need to pass a null_randval_dram and tile window to the BlockDropout
                // operator to make it work
                const auto null_randval_dram = [&]() {
                    const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                        static_cast<uint8_t*>(nullptr),
                        make_tuple(1, 1),
                        make_tuple(1, 1),
                        number<1>{},
                        number<1>{});

                    return pad_tensor_view(null_dram_naive,
                                           make_tuple(number<1>{}, number<1>{}),
                                           sequence<true, true>{});
                }();

                return make_tile_window(
                    null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
            }
            else
                return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
        }();

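A minimal sketch of the same "null resource" pattern in isolation: when a feature is compiled out, a cheap stand-in object is produced so the downstream operator keeps a single call signature. All names below are hypothetical and only illustrate the idea:

#include <cstdio>

struct RealWindow { const char* name = "real"; };
struct NullWindow { const char* name = "null"; };

// Illustrative sketch: compile-time selection between a real window and a placeholder.
template <bool kHasDropout>
auto make_randval_window()
{
    if constexpr(kHasDropout)
        return RealWindow{}; // would wrap an actual tensor view in the real pipeline
    else
        return NullWindow{}; // placeholder; never dereferenced when dropout is off
}

int main()
{
    auto w0 = make_randval_window<true>();
    auto w1 = make_randval_window<false>();
    std::printf("%s %s\n", w0.name, w1.name);
    return 0;
}
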
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
|
||||
|
||||
do
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
if constexpr(kPreloadWholeNextIterationK) // used when kM0 = 64
|
||||
{
|
||||
if(seqlen_k_curr == seqlen_k_start) // at first iteration
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 < n0_loops - 1)
|
||||
{
|
||||
k_tiles[number<i_n0 + 1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
|
||||
// prefetch all k_tiles for next iteration
|
||||
static_for<0, n0_loops, 1>{}([&](auto ii_n0) {
|
||||
k_tiles[number<ii_n0>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
}
|
||||
else // the iteration is also the last iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 < n0_loops - 1)
|
||||
{
|
||||
k_tiles[number<i_n0 + 1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
};
|
||||
}
|
||||
else // at intermediate and last iteration
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 == 0)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
// prefetch k_tile for next iteration
|
||||
k_tiles[i_n0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
}
|
||||
else // last iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 == 0)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
};
|
||||
}
|
||||
}
|
||||
else // only preload one unroll of K for next iteration, used when kM0=128
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0 % 2>{}]),
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(i_n0 < n0_loops - 2)
|
||||
{
|
||||
k_tiles[number<i_n0 % 2>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 >= n0_loops - 2)
|
||||
{
|
||||
v_tiles[number<i_n0 - (n0_loops - 2)>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x000000001);
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto y) {
|
||||
x += type_convert<CompDataType>(bias_element_func(y));
|
||||
},
|
||||
pcomp_tile,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
constexpr auto pcomp_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(pcomp_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(pcomp_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
pcomp_tile(i_j_idx) *= scale_s;
|
||||
position_encoding.update(pcomp_tile(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
|
||||
}
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kM0>{}, number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(pcomp_tile, -numeric<CompDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto m_local = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m;
|
||||
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tiles[I0]),
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(kPreloadWholeNextIterationK)
|
||||
{
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) {
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<2, NumPrefetchV, 1>{}([&](auto i_k1) {
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
auto rowsum_p =
|
||||
block_tile_reduce<CompDataType>(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0});
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
            // adjust o_acc[] according to the update between m and m_old
            constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
            sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);

                if(m[i_idx] == -numeric<CompDataType>::infinity())
                {
                    l(i_idx) = rowsum_p[i_idx];
                }
                else
                {
                    const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
                    l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
                    sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
                        o_acc(i_j_idx) *= tmp;
                    });
                }
            });

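For reference, a minimal scalar sketch of the running-softmax update this block performs per row (the standard FlashAttention-style recurrence); plain C++ for illustration only, not the tile code:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

// m: running row maximum, l: running row sum of exponentials, o: running output accumulator.
// When a new block of scores s[] arrives, the previous accumulators are rescaled by
// exp(m_old - m_new) so that the final o / l equals softmax over all scores applied to v.
struct RowState
{
    double m = -std::numeric_limits<double>::infinity();
    double l = 0.0;
    double o = 0.0; // one output column, for brevity
};

void update_row(RowState& st, const std::vector<double>& s, const std::vector<double>& v)
{
    double m_local = -std::numeric_limits<double>::infinity();
    for(double x : s)
        m_local = std::max(m_local, x);

    const double m_new = std::max(st.m, m_local);
    const double scale = std::isinf(st.m) ? 0.0 : std::exp(st.m - m_new);

    double rowsum = 0.0, pv = 0.0;
    for(std::size_t j = 0; j < s.size(); ++j)
    {
        const double p = std::exp(s[j] - m_new);
        rowsum += p;
        pv += p * v[j];
    }

    st.o = st.o * scale + pv;
    st.l = st.l * scale + rowsum;
    st.m = m_new;
}

int main()
{
    RowState st;
    update_row(st, {0.1, 0.7}, {1.0, 2.0});
    update_row(st, {1.3, -0.2}, {3.0, 4.0});
    std::printf("attention output = %f\n", st.o / st.l); // matches a one-shot softmax
    return 0;
}
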
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
}
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
{
|
||||
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_k1 == k1_loops - NumPrefetchV)
|
||||
{
|
||||
if constexpr(!kPreloadWholeNextIterationK)
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end)
|
||||
{
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr(i_k1 == k1_loops - NumPrefetchV + 1)
|
||||
{
|
||||
if constexpr(!kPreloadWholeNextIterationK)
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end)
|
||||
{
|
||||
k_tiles[I1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]),
|
||||
partition_index);
|
||||
};
|
||||
});
|
||||
|
||||
// check whether last V-LdsBuffer overlap with first K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
        // store LSE
        if constexpr(kStoreLSE)
        {
            auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

            constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
            sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
            });

            store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
        }

        // finally, normalize O by the row sums
        constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();

        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);

                if(m[i_idx] == -numeric<CompDataType>::infinity())
                    o_acc(i_j_idx) = 0.0f;
                else
                    o_acc(i_j_idx) *= 1.0f / l[i_idx];
            });
        });

        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);

        return o_acc;
    }

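A short sketch of why m + log(l) is the quantity worth storing: with the per-row lse saved, any softmax probability can be recomputed from the raw score alone as exp(s - lse), without a separate max/sum pass (useful, for example, in the backward pass). Plain C++, illustrative only:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const std::vector<double> s = {0.5, 1.5, -0.25};

    // lse = m + log(sum exp(s - m)), computed the numerically stable way
    double m = s[0];
    for(double x : s) m = std::max(m, x);
    double sum = 0.0;
    for(double x : s) sum += std::exp(x - m);
    const double lse = m + std::log(sum);

    double total = 0.0;
    for(double x : s) total += std::exp(x - lse); // softmax probabilities, summing to 1
    std::printf("sum of probabilities = %f\n", total);
    return 0;
}
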
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
ignore = sink_v;
|
||||
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -692,8 +692,11 @@ struct BlockFmhaPipelineQSKSVS
               const AttentionVariantParams& variant_params,
               const BlockIndices& block_indices,
               void* smem_ptr,
               DropoutType& dropout) const
               DropoutType& dropout,
               const float sink_v) const
    {
        ignore = sink_v;

        return operator()(q_dram_block_window_tmp,
                          identity{},
                          k_dram_block_window_tmp,

@@ -57,7 +57,7 @@ struct TileFmhaShape
    static constexpr index_t kQKHeaddim =
        BlockTile::at(number<5>{}); // total length of K0, used for pipelines that need to load Q
                                    // at once (or repeatedly load Q as a whole tile)
    static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
    static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim must be divisible by kK0!");

    static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();


@@ -5,6 +5,7 @@

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
@@ -58,7 +59,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
          BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
              BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
          BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
              BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
              BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
          BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
              BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
                                                          kPadSeqLenK_,
                                                          kPadHeadDimQ_,
@@ -76,6 +79,7 @@ struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
    static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
    static constexpr auto kKVLookupTable  = kKVLookupTable_;
    static constexpr index_t kPageBlockSize = kPageBlockSize_;
    static constexpr auto kKVLoadMode = kKVLoadMode_;
    static_assert(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT ||
                      kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT,
                  "Batch prefill only supports vectorized or linear KV cache layout.");

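The hunk above extends the batch-prefill traits with a new kKVLoadMode_ parameter while keeping existing instantiations valid, because the parameter is appended at the end of the list with a default. A minimal standalone sketch of that backward-compatible pattern (all names hypothetical):

#include <cstdio>

enum class LoadMode { BufferLoad, AsyncCopy };

// Illustrative sketch: a new, defaulted trailing template parameter leaves old uses untouched.
template <bool kPad, LoadMode kLoadMode = LoadMode::BufferLoad>
struct Traits
{
    static constexpr bool pad      = kPad;
    static constexpr LoadMode mode = kLoadMode;
};

using OldStyle = Traits<true>;                      // predates the new parameter, still compiles
using NewStyle = Traits<true, LoadMode::AsyncCopy>; // opts into the new behaviour

int main()
{
    std::printf("%d %d\n", static_cast<int>(OldStyle::mode), static_cast<int>(NewStyle::mode));
    return 0;
}
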
@@ -1685,7 +1685,7 @@ struct MoeSortingMultiPhaseKernel_P0_v1
                IndexType eid = x[j.value]; // ext_vector_type must use int to []
                uint32_t curr_token_id, curr_topk_id;
                kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
                if(eid < kargs.num_experts)
                if(eid < kargs.num_experts && eid >= 0)
                {
                    if constexpr(Problem::LocalToken)
                    {

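The added eid >= 0 guard matters because the expert id is read as a signed value, and an invalid or padded entry may be negative; an upper-bound-only check would let such an index through. A minimal sketch of the failure mode:

#include <cassert>

// Illustrative only: a signed index compared against the upper bound alone silently
// accepts negative values; the two-sided check rejects them.
bool in_range_one_sided(int eid, int num_experts) { return eid < num_experts; }
bool in_range_two_sided(int eid, int num_experts) { return eid >= 0 && eid < num_experts; }

int main()
{
    assert(in_range_one_sided(-1, 8));  // accepted: would index out of bounds
    assert(!in_range_two_sided(-1, 8)); // rejected by the added guard
    assert(in_range_two_sided(3, 8));
    return 0;
}
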
@@ -0,0 +1,268 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"

namespace ck_tile {

// A is a block distributed tensor (registers)
// B is a block window on shared memory
// C is a block distributed tensor (registers)
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
struct BlockGemmARegBSmemCRegV2PrefetchK
{
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id<false>() % NWarp;
|
||||
|
||||
static_assert(NWarp == 1, "Check failed!");
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
|
||||
MakeABlockTileDistribution());
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));
|
||||
|
||||
statically_indexed_array<b_warp_tensor_type, KIterPerWarp> b_warp_tensors;
|
||||
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(nIter)(I0) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(I0),
|
||||
{nIter * NPerBlockPerIter, 0 * KPerBlockPerIter});
|
||||
b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
if constexpr(kIter < KIterPerWarp - 1)
|
||||
{
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(nIter)(number<kIter + 1>{}) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(number<kIter + 1>{}),
|
||||
{nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter});
|
||||
b_warp_tensors[number<kIter + 1>{}] =
|
||||
load_tile(b_warp_windows(nIter)(number<kIter + 1>{}));
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
if constexpr(kIter == 0)
|
||||
{
|
||||
// warp GEMM
|
||||
c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[kIter]);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
|
||||
}
|
||||
else
|
||||
{
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
|
||||
};
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
|
||||
{
|
||||
constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode<MPerBlock, KPerBlock>();
|
||||
|
||||
return make_static_tile_distribution(a_block_dstr_encode);
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
|
||||
static_assert(NWarp == 1, "Check failed!");
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
return c_block_dstr_encode;
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode<MPerBlock, NPerBlock>();
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};

} // namespace ck_tile

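The PrefetchK block GEMM above interleaves the load of the next K-slice of B with the warp GEMMs on the current slice. A minimal scalar sketch of that software-pipelining idea, double-buffering B along K (plain C++, not the ck_tile code; tile sizes chosen only for illustration):

#include <cstdio>
#include <vector>

int main()
{
    constexpr int M = 4, N = 4, K = 8, KSlice = 2, KIters = K / KSlice;
    std::vector<float> A(M * K, 1.0f), B(K * N, 2.0f), C(M * N, 0.0f);

    std::vector<float> b_buf[2]; // two in-flight B slices (register/LDS stand-in)
    auto fetch_b = [&](int kIter) {
        std::vector<float> slice(KSlice * N);
        for(int kk = 0; kk < KSlice; ++kk)
            for(int n = 0; n < N; ++n)
                slice[kk * N + n] = B[(kIter * KSlice + kk) * N + n];
        return slice;
    };

    b_buf[0] = fetch_b(0); // prologue: first K-slice
    for(int kIter = 0; kIter < KIters; ++kIter)
    {
        if(kIter + 1 < KIters)
            b_buf[(kIter + 1) % 2] = fetch_b(kIter + 1); // prefetch the next K-slice

        const auto& b = b_buf[kIter % 2]; // compute on the current slice
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int kk = 0; kk < KSlice; ++kk)
                    C[m * N + n] += A[m * K + kIter * KSlice + kk] * b[kk * N + n];
    }

    std::printf("C[0] = %f (expected %f)\n", C[0], 2.0f * K);
    return 0;
}
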
@@ -0,0 +1,239 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
|
||||
struct BlockGemmARegBSmemCRegV2PrefetchN
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id<false>() % NWarp;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
|
||||
MakeABlockTileDistribution());
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
statically_indexed_array<b_warp_tensor_type, NIterPerWarp> b_warp_tensors;
|
||||
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(I0)(kIter),
|
||||
{0 * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
b_warp_tensors(I0) = load_tile(b_warp_windows(I0)(kIter));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter < NIterPerWarp - 1)
|
||||
{
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
|
||||
{(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
b_warp_tensors(number<nIter + 1>{}) =
|
||||
load_tile(b_warp_windows(number<nIter + 1>{})(kIter));
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
return make_static_tile_distribution(a_block_dstr_encode);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};

} // namespace ck_tile

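The PrefetchN variant orders the loop the other way round: within a fixed K-slice, the B columns for the next N-iteration are requested before the warp GEMMs of the current N-iteration are issued, so the prefetch distance is along the output-column dimension. A simplified scalar sketch of that ordering (illustrative sizes only):

#include <cstdio>
#include <vector>

int main()
{
    constexpr int M = 2, N = 6, K = 4, NSlice = 2, NIters = N / NSlice;
    std::vector<float> A(M * K, 1.0f), B(K * N, 3.0f), C(M * N, 0.0f);

    auto fetch_b_cols = [&](int nIter) {
        std::vector<float> cols(K * NSlice);
        for(int k = 0; k < K; ++k)
            for(int nn = 0; nn < NSlice; ++nn)
                cols[k * NSlice + nn] = B[k * N + nIter * NSlice + nn];
        return cols;
    };

    std::vector<float> b_cols[2];
    b_cols[0] = fetch_b_cols(0);
    for(int nIter = 0; nIter < NIters; ++nIter)
    {
        if(nIter + 1 < NIters)
            b_cols[(nIter + 1) % 2] = fetch_b_cols(nIter + 1); // prefetch next column slice

        const auto& b = b_cols[nIter % 2];
        for(int m = 0; m < M; ++m)
            for(int nn = 0; nn < NSlice; ++nn)
                for(int k = 0; k < K; ++k)
                    C[m * N + nIter * NSlice + nn] += A[m * K + k] * b[k * NSlice + nn];
    }

    std::printf("C[0] = %f (expected %f)\n", C[0], 3.0f * K);
    return 0;
}
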
@@ -0,0 +1,243 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
|
||||
struct BlockGemmARegBSmemTrLoadCRegV2PrefetchN
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id<false>() % NWarp;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
// construct from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
|
||||
MakeABlockTileDistribution());
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
constexpr auto b_warp_dstr_encode =
|
||||
typename InputTileDistributionTraits<typename WG::BWarpDstrEncoding,
|
||||
BDataType>::TransposedDstrEncode{};
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kK>{}, number<WG::kN>{}),
|
||||
b_block_window_tmp.get_window_origin() + multi_index<2>{0, iNWarp * WG::kN},
|
||||
make_static_tile_distribution(b_warp_dstr_encode));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
using b_warp_tensor_type = decltype(load_tile_transpose(b_warp_windows(I0)(I0)));
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
statically_indexed_array<b_warp_tensor_type, NIterPerWarp> b_warp_tensors;
|
||||
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(I0)(kIter),
|
||||
{kIter * KPerBlockPerIter, 0 * NPerBlockPerIter});
|
||||
b_warp_tensors(I0) = load_tile_transpose(b_warp_windows(I0)(kIter));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter < NIterPerWarp - 1)
|
||||
{
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
|
||||
{kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter});
|
||||
b_warp_tensors(number<nIter + 1>{}) =
|
||||
load_tile_transpose(b_warp_windows(number<nIter + 1>{})(kIter));
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
return make_static_tile_distribution(a_block_dstr_encode);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};

} // namespace ck_tile

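The TrLoad variant differs from the plain PrefetchN block GEMM in how B is read from shared memory: the warp window is K-major (WG::kK by WG::kN) and load_tile_transpose returns it in the layout the warp GEMM expects. A scalar sketch of reading a K-major tile and transposing it while copying (illustrative only, not the ck_tile transpose-load path):

#include <cassert>
#include <vector>

// B is stored row-major as [K][N]; the helper returns it transposed as [N][K] so the
// inner MAC loop can walk one output column contiguously.
std::vector<float> load_tile_transposed(const std::vector<float>& b, int K, int N)
{
    std::vector<float> bt(N * K);
    for(int k = 0; k < K; ++k)
        for(int n = 0; n < N; ++n)
            bt[n * K + k] = b[k * N + n]; // transpose while copying into "registers"
    return bt;
}

int main()
{
    const int K = 2, N = 3;
    const std::vector<float> b = {1, 2, 3,
                                  4, 5, 6};            // [K][N]
    const auto bt = load_tile_transposed(b, K, N);     // [N][K]
    assert(bt[0 * K + 1] == 4.0f); // column 0 of B becomes row 0 of the transposed tile
    return 0;
}
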
@@ -0,0 +1,85 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using BF8 = ck::bf8_t;
using F8 = ck::f8_t;

using Empty_Tuple = ck::Tuple<>;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using namespace ck::tensor_layout::convolution;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default;

static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
    ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_v3_f16_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>

// clang-format on
>;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_v3_bf16_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>

// clang-format on
>;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -108,6 +108,8 @@ struct DeviceOperationInstanceFactory<
                     is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
                     is_same_v<ComputeTypeB, F16>)
        {
            add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
                op_ptrs);
            add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
            add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
                op_ptrs);
@@ -148,6 +150,8 @@ struct DeviceOperationInstanceFactory<
                     is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
                     is_same_v<ComputeTypeB, BF16>)
        {
            add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
                op_ptrs);
            add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
                op_ptrs);
            add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
@@ -56,6 +56,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
#endif

#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
@@ -232,6 +246,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loa
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  BF16,
                                                                  BF16,
                                                                  Empty_Tuple,
                                                                  BF16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
#endif

#ifdef CK_ENABLE_FP16

@@ -32,6 +32,8 @@ add_instance_library(
    xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp
    xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp
    xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp
    xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
    xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp

    wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
    wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp
@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  BF16,
                                                                  BF16,
                                                                  Empty_Tuple,
                                                                  BF16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2,
                                                           NHWGK,
                                                           GKYXC,
                                                           Empty_Tuple,
                                                           NHWGC,
                                                           ConvBwdDataDefault>{});
    // 2. Filter1x1Stride1Pad0
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2,
                                                           NHWGK,
                                                           GKYXC,
                                                           Empty_Tuple,
                                                           NHWGC,
                                                           ConvBwdDataFilter1x1Stride1Pad0>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_v3_f16_instances<2,
                                                          NHWGK,
                                                          GKYXC,
                                                          Empty_Tuple,
                                                          NHWGC,
                                                          ConvBwdDataDefault>{});
    // 2. Filter1x1Stride1Pad0
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_v3_f16_instances<2,
                                                          NHWGK,
                                                          GKYXC,
                                                          Empty_Tuple,
                                                          NHWGC,
                                                          ConvBwdDataFilter1x1Stride1Pad0>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
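Note on usage: the instance lists registered above are normally reached through CK's DeviceOperationInstanceFactory rather than called directly. The short sketch below enumerates the f16 NHWGK/GKYXC/NHWGC grouped convolution backward-data instances; the DeviceGroupedConvBwdDataMultipleD template arguments come from the declarations above, while the factory include path and the GetTypeString() call are assumptions based on the usual CK client pattern, not something defined in these files.

// Usage sketch (assumed: factory include path and GetTypeString()); the template
// arguments mirror the f16 NHWGK/GKYXC/NHWGC declarations above.
#include <iostream>

#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" // assumed path

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
namespace ctc     = ck::tensor_layout::convolution;

using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
    2,            // NDimSpatial
    ctc::NHWGK,   // A (output gradient) layout
    ctc::GKYXC,   // B (weight) layout
    ck::Tuple<>,  // Ds layouts
    ctc::NHWGC,   // E (input gradient) layout
    ck::half_t,   // AData
    ck::half_t,   // BData
    ck::Tuple<>,  // DsData
    ck::half_t,   // EData
    PassThrough,  // A elementwise op
    PassThrough,  // B elementwise op
    PassThrough>; // CDE elementwise op

int main()
{
    // GetInstances() collects every instance registered for this operation
    // signature, which now includes the xdl_v3 f16 instances added above.
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << std::endl;

    return 0;
}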