Merge branch 'develop' into users/yiding12/fmha-bwd-workspace

This commit is contained in:
Yi DING
2026-04-27 15:07:41 +08:00
committed by GitHub
50 changed files with 5216 additions and 1120 deletions

View File

@@ -10,27 +10,31 @@ RUN pip install pandas zmq einops ninja tabulate vcs_versioning && \
sudo mkdir /home/jenkins/workspace && \
cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \
if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \
git clone --depth 1 -b "$CK_AITER_BRANCH" --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \
cd rocm-libraries && \
mkdir rocm-libraries && cd rocm-libraries && \
git init -q && \
git remote add origin https://github.com/ROCm/rocm-libraries.git && \
git fetch --depth 1 --filter=blob:none origin "$CK_AITER_BRANCH" && \
git sparse-checkout init --cone && \
git sparse-checkout set projects/composablekernel && \
git checkout "$CK_AITER_BRANCH" && \
git checkout FETCH_HEAD && \
ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \
LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}" && \
mv projects/composablekernel ../ck && \
cd ../ck && rm -rf ../rocm-libraries && \
git init && \
git init -b "$LOCAL_BRANCH" && \
git config user.name "assistant-librarian[bot]" && \
git config user.email "assistant-librarian[bot]@users.noreply.github.com" && \
git branch -m "$CK_AITER_BRANCH" && git add -A && \
git add -A && \
git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" ; \
else \
git clone --depth 1 -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git ck ; \
git clone --depth 1 -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git ck && \
LOCAL_BRANCH="$CK_AITER_BRANCH" ; \
fi && \
cd /home/jenkins/workspace && rm -rf aiter && \
git clone --depth 1 -b "$AITER_BRANCH" --recursive https://github.com/ROCm/aiter.git && \
cd aiter && \
rm -rf 3rdparty/composable_kernel/ && \
git clone -b "$CK_AITER_BRANCH" ../ck 3rdparty/composable_kernel/ && \
git clone -b "$LOCAL_BRANCH" ../ck 3rdparty/composable_kernel/ && \
python3 setup.py develop && \
groupadd -g 1001 jenkins && \
useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \

View File

@@ -12,27 +12,31 @@ RUN set -x ; \
sudo mkdir /home/jenkins/workspace && \
cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \
if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \
git clone --depth 1 -b "$CK_FA_BRANCH" --no-checkout --filter=blob:none https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \
cd rocm-libraries && \
mkdir rocm-libraries && cd rocm-libraries && \
git init -q && \
git remote add origin https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \
git fetch --depth 1 --filter=blob:none origin "$CK_FA_BRANCH" && \
git sparse-checkout init --cone && \
git sparse-checkout set projects/composablekernel && \
git checkout "$CK_FA_BRANCH" && \
git checkout FETCH_HEAD && \
ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \
LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}" && \
mv projects/composablekernel ../ck && \
cd ../ck && rm -rf ../rocm-libraries && \
git init && \
git init -b "$LOCAL_BRANCH" && \
git config user.name "assistant-librarian[bot]" && \
git config user.email "assistant-librarian[bot]@users.noreply.github.com" && \
git branch -m "$CK_FA_BRANCH" && git add -A && \
git add -A && \
git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" > /dev/null ; \
else \
git clone --depth 1 -b "$CK_FA_BRANCH" https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck ; \
git clone --depth 1 -b "$CK_FA_BRANCH" https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck && \
LOCAL_BRANCH="$CK_FA_BRANCH" ; \
fi && \
cd /home/jenkins/workspace && rm -rf flash-attention && \
git clone --depth 1 -b "$FA_BRANCH" --recursive "https://github.com/$FA_ORIGIN/flash-attention.git" && \
cd flash-attention && \
rm -rf csrc/composable_kernel/ && \
git clone -b "$CK_FA_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \
git clone -b "$LOCAL_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \
MAX_JOBS=$(nproc) GPU_ARCHS="$GPU_ARCHS" /opt/venv/bin/python3 -u -m pip install --no-build-isolation -v . && \
groupadd -g 1001 jenkins && \
useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \

Jenkinsfile vendored
View File

@@ -840,8 +840,10 @@ def cmake_build(Map conf=[:]){
if (params.RUN_CK_TILE_FMHA_TESTS){
try{
archiveArtifacts "perf_fmha_*.log"
stash includes: "perf_fmha_**.log", name: "perf_fmha_log_${arch_name}"
dir("projects/composablekernel"){
archiveArtifacts "perf_fmha_*.log"
stash includes: "perf_fmha_**.log", name: "perf_fmha_log_${arch_name}"
}
}
catch(Exception err){
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
@@ -918,7 +920,7 @@ def Build_CK(Map conf=[:]){
sh "projects/composablekernel/script/run_inductor_tests.sh"
}
// run performance tests, stash the logs, results will be processed on the master node
dir("projects/composablekernel/script"){
dir("projects/composablekernel/script"){
if (params.RUN_PERFORMANCE_TESTS){
if (params.RUN_FULL_QA && (arch == "gfx90a" || arch == "gfx942")){
// run full tests on gfx90a or gfx942
@@ -1017,6 +1019,13 @@ def process_results(Map conf=[:]){
catch(Exception err){
echo "could not locate the FMHA performance logs for gfx90a: ${err.getMessage()}."
}
try{
unstash "perf_fmha_log_gfx950"
}
catch(Exception err){
echo "could not locate the FMHA performance logs for gfx950: ${err.getMessage()}."
}
}
if (params.BUILD_INSTANCES_ONLY){
// unstash deb packages
@@ -1191,7 +1200,7 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_
0 13 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true
0 11 * * * % RUN_FULL_CONV_TILE_TESTS=true;RUN_AITER_TESTS=true;RUN_FA_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true
0 9 * * * % RUN_PYTORCH_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX101=false;BUILD_GFX103=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : ""
CURRENT_BRANCH_NAME = env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME
CURRENT_BRANCH_NAME = env.CHANGE_ID ? "refs/pull/${env.CHANGE_ID}/head" : (env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME)
POLL_SPEC = BRANCH_NAME == "develop" ? 'H H/6 * * *' : ''

View File

@@ -108,28 +108,35 @@ bool run_grouped_conv_fwd(bool do_verification,
if(do_verification)
{
Tensor<AccDataType> c_host(out_g_n_k_wos_desc);
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp>();
PassThrough>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei,
out_host,
c_host,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
out_element_op);
PassThrough{});
ref_invoker.Run(ref_argument);
out_host.ForEach([&](auto&, auto idx)
{
out_element_op(out_host(idx), c_host(idx));
});
out_device_buf.FromDevice(out_device.mData.data());
pass &=

View File

@@ -22,8 +22,16 @@ from codegen.cpp_symbol_map import (
QSCALE_CHECK_MAP,
QSCALE_MAP,
)
from codegen.arch import ArchTrait
from codegen.utils import update_file
# Architecture trait for kernels requiring global_load_lds (CDNA3+).
# Only used for GLOBAL_LOAD_LDS variants; all other kernels are arch-agnostic.
CDNA3_PLUS_ARCH = ArchTrait(
"cdna3_plus",
preprocessor_check="defined(__gfx94__) || defined(__gfx950__)",
)
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
@@ -34,6 +42,10 @@ DTYPE_BITS = {
"bf8": 8,
}
# Element size in bytes per dtype, used by the auto-generated dispatcher to
# decide kv_load_mode per-arm (total KV cache bytes vs INT32_MAX).
DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()}
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
@@ -47,6 +59,10 @@ KV_LOOKUP_TABLE_ENUM_MAP = {
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
}
KV_LOAD_MODE_ENUM_MAP = {
False: "ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD",
True: "ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS",
}
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
@@ -61,6 +77,8 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
"""
FMHA_FWD_KERNEL_BODY = """
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
@@ -87,7 +105,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
{F_sink},
{F_page_size},
{F_kv_memory_layout},
{F_kv_lookup_table}>;
{F_kv_lookup_table},
{F_kv_load_mode}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
@@ -125,7 +144,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
#include <iostream>
@@ -140,10 +159,13 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
"""
FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp"
FMHA_FWD_API = """
#include <cstdint>
#include <cstdio>
namespace {{
@@ -194,6 +216,7 @@ float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a,
"""
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
constexpr int kElementBytes = {F_element_bytes};
{F_hdim_case}
}}
"""
@@ -203,8 +226,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && (fmha_batch_prefill_select_kv_load_mode(a.page_block_size, {F_bn0}, a.num_total_pages, a.batch_stride_k, kElementBytes) == {F_kv_load_mode})) {{
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@@ -253,12 +276,14 @@ class FmhaFwdApiTrait:
kv_memory_layout: str
kv_lookup_table: str
page_size: int = 1 # page block size
use_global_load: bool = False # use global_load_lds_* for >2GB KV cache
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
+ ("-gload" if self.use_global_load else "-bload")
)
@property
@@ -481,6 +506,7 @@ class FmhaFwdApiPool:
],
F_page_size=trait.page_size,
F_sink=BOOL_MAP[trait.sink],
F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[trait.use_global_load],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
@@ -488,7 +514,10 @@ class FmhaFwdApiPool:
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
F_if=if_i,
F_dtype=dtype,
F_element_bytes=DTYPE_BYTES[dtype],
F_hdim_case=per_hdim_case,
)
if not per_dtypes:
# if empty, add some ignores to suppress warnings in the api
@@ -539,6 +568,7 @@ class FmhaFwdKernel:
F_pipeline: FmhaFwdPipeline
mask_impl: str
F_page_size: int = 1 # page block size
F_use_global_load: bool = False # use global_load_lds_* for >2GB KV cache
@property
def template(self) -> str:
@@ -588,6 +618,10 @@ class FmhaFwdKernel:
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
F_page_size=self.F_page_size,
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[self.F_use_global_load],
F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check
if self.F_use_global_load
else "true",
)
@property
@@ -595,6 +629,7 @@ class FmhaFwdKernel:
# TODO: we don't encode idx here
return (
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
+ ("gload_" if self.F_use_global_load else "bload_")
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
@@ -632,6 +667,7 @@ class FmhaFwdKernel:
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
page_size=self.F_page_size,
use_global_load=self.F_use_global_load,
)
@@ -714,8 +750,11 @@ class CustomFactory(KernelComponentFactory):
def get_fwd_blobs(
kernel_filter: Optional[str], receipt, optdim_list, mask_impl,
targets: Optional[List[str]] = None
kernel_filter: Optional[str],
receipt,
optdim_list,
mask_impl,
targets: Optional[List[str]] = None,
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing
# (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with
@@ -837,6 +876,25 @@ def get_fwd_blobs(
api_pool.register_traits(k.api_trait())
gen.append(k)
# For page_size < kN0 (tile.F_bn0), also generate a GLOBAL_LOAD_LDS
# variant for >2GB KV cache support. The default (BUFFER_LOAD) uses SRD
# buffer_load (fast, <2GB). GLOBAL_LOAD_LDS uses global_load_lds_*
# (slower, handles >2GB).
if page_size < tile.F_bn0:
k_global_load = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
F_page_size=page_size,
F_use_global_load=True,
)
api_pool.register_traits(k_global_load.api_trait())
gen.append(k_global_load)
return (api_pool, gen)
@@ -856,7 +914,9 @@ def write_blobs(
optdim_list,
mask_impl,
) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
api_pool, kernels = get_fwd_blobs(
kernel_filter, receipt, optdim_list, mask_impl, targets
)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
@@ -871,7 +931,9 @@ def list_blobs(
mask_impl,
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
_, kernels = get_fwd_blobs(
kernel_filter, receipt, optdim_list, mask_impl, targets
)
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")

View File

@@ -673,6 +673,33 @@ struct fmha_batch_prefill_args
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
};
// Selects the KV-cache load mode for a batch-prefill dispatch arm.
// GLOBAL_LOAD_LDS: required when (a) the page is smaller than one K/V tile
// so per-page SRD is impossible, AND (b) the total KV-pool byte size
// exceeds INT32_MAX so SRD's 32-bit byte offset cannot address it.
// BUFFER_LOAD: every other case — the SGPR-resident SRD path is fastest.
// Inputs are taken as plain integers so the helper has no template parameter
// and can be called from each codegen-emitted dispatcher arm with the arm's
// compile-time kN0 / element_bytes substituted as constants.
inline ck_tile::BlockAttentionKVCacheLoadModeEnum
fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size,
ck_tile::index_t kN0,
ck_tile::index_t num_total_pages,
ck_tile::index_t batch_stride_k,
ck_tile::index_t element_bytes)
{
// Promote every operand to long_index_t so overflow is impossible regardless
// of multiplication order. A bare `static_cast<long_index_t>(num_total_pages)
// * batch_stride_k * element_bytes` only works because of left-to-right
// associativity — a future reorder of the operands would silently truncate.
const auto kv_pool_bytes = static_cast<ck_tile::long_index_t>(num_total_pages) *
static_cast<ck_tile::long_index_t>(batch_stride_k) *
static_cast<ck_tile::long_index_t>(element_bytes);
return (page_block_size < kN0 && kv_pool_bytes > INT32_MAX)
? ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS
: ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD;
}
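As a worked example of the threshold described above, the following standalone sketch mirrors the helper's arithmetic with illustrative pool dimensions (the concrete numbers are assumptions chosen only to cross the 2 GB limit; only the decision rule comes from the helper):

// Plain C++ illustration, no ck_tile dependencies; mirrors the rule above.
#include <cstdint>
#include <cstdio>

int main()
{
    // Example pool: 70,000 pages of 16 tokens, 8 KV heads x 128 head dim, bf16 (2 bytes).
    const std::int64_t num_total_pages = 70000;
    const std::int64_t batch_stride_k  = 16 * 8 * 128; // elements per page
    const std::int64_t element_bytes   = 2;
    const std::int64_t page_block_size = 16;
    const std::int64_t kN0             = 128;          // KV tile width of this dispatch arm

    // All operands are 64-bit here, matching the helper's promotion to long_index_t.
    const std::int64_t kv_pool_bytes = num_total_pages * batch_stride_k * element_bytes;
    const bool use_global_load = page_block_size < kN0 && kv_pool_bytes > INT32_MAX;
    // kv_pool_bytes = 2,293,760,000 (~2.14 GiB) > INT32_MAX, so GLOBAL_LOAD_LDS is chosen.
    std::printf("kv_pool_bytes=%lld -> %s\n",
                static_cast<long long>(kv_pool_bytes),
                use_global_load ? "GLOBAL_LOAD_LDS" : "BUFFER_LOAD");
}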
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
{
@@ -1457,7 +1484,9 @@ template <ck_tile::index_t HDim_,
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
ck_tile::BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
DataType_,
kIsGroupMode_,
@@ -1486,6 +1515,7 @@ struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
static constexpr auto kKVLoadMode = kKVLoadMode_;
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
};

View File

@@ -33,8 +33,9 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false,
bool LdsScalarLoadToVgpr = false>
bool TransposeC = false,
bool ALdsScalarLoadToVgpr = false,
bool BLdsScalarLoadToVgpr = false>
struct BlockwiseGemmXdlops_pipeline_base
{
static constexpr auto I0 = Number<0>{};
@@ -389,7 +390,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
LdsScalarLoadToVgpr ? 1 : A_K1,
ALdsScalarLoadToVgpr ? 1 : A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
@@ -399,7 +400,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
LdsScalarLoadToVgpr ? 1 : B_K1,
BLdsScalarLoadToVgpr ? 1 : B_K1,
B_K1>;
AThreadCopy a_thread_copy_;

View File

@@ -32,12 +32,13 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool DirectLoad = false,
bool LdsScalarLoadToVgpr = false>
bool DirectLoad = false,
bool ALdsScalarLoadToVgpr = false,
bool BLdsScalarLoadToVgpr = false>
constexpr auto BlockGemmPipeline_Selector()
{
// Supported for Direct Load and V1
if constexpr(LdsScalarLoadToVgpr)
if constexpr(ALdsScalarLoadToVgpr || BLdsScalarLoadToVgpr)
{
static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1);
}
@@ -65,7 +66,8 @@ constexpr auto BlockGemmPipeline_Selector()
MRepeat,
NRepeat,
KPack,
LdsScalarLoadToVgpr>{};
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{

View File

@@ -747,7 +747,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPacks,
bool LdsScalarLoadToVgpr = false>
bool ALdsScalarLoadToVgpr = false,
bool BLdsScalarLoadToVgpr = false>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1
{
};
@@ -772,7 +773,8 @@ template <index_t BlockSize,
index_t NRepeat,
index_t KPack,
// ,bool TransposeC //disable transposec right now...
bool LdsScalarLoadToVgpr>
bool ALdsScalarLoadToVgpr,
bool BLdsScalarLoadToVgpr>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
@@ -793,7 +795,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
MRepeat,
NRepeat,
KPack,
LdsScalarLoadToVgpr>
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
@@ -814,7 +817,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NRepeat,
KPack,
false /*TransposeC*/,
LdsScalarLoadToVgpr>
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>
{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -837,7 +841,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NRepeat,
KPack,
false /*TransposeC*/,
LdsScalarLoadToVgpr>;
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>;
using Base::I0;
using Base::KRepeat;
using Base::xdlops_gemm;

View File

@@ -19,6 +19,7 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
@@ -856,6 +857,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
input_right_pads_{input_right_pads}
{
k_batch_ = split_k;
k_batch_ = clamp_gemm_k_batch(k_batch_);
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(

View File

@@ -179,6 +179,7 @@ struct DeviceGroupedConvBwdWeight_Explicit
k_batch_ = split_k;
}
}
k_batch_ = clamp_gemm_k_batch(k_batch_);
if constexpr(IsTwoStageNeeded)
{

View File

@@ -670,6 +670,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
{
k_batch_ = split_k;
}
k_batch_ = clamp_gemm_k_batch(k_batch_);
const auto descs =
conv_to_gemm_transformer

View File

@@ -695,6 +695,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
k_batch_ = split_k;
}
k_batch_ = clamp_gemm_k_batch(k_batch_);
const auto descs =
conv_to_gemm_transformer

View File

@@ -611,6 +611,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
{
k_batch_ = split_k;
}
k_batch_ = clamp_gemm_k_batch(k_batch_);
const auto descs =
conv_to_gemm_transformer_v2

View File

@@ -717,6 +717,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
k_batch_ = split_k;
}
k_batch_ = clamp_gemm_k_batch(k_batch_);
// Create initial descriptors with hack=false to check compactness
const auto descs_initial =

View File

@@ -555,6 +555,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
{
k_batch_ = split_k;
}
k_batch_ = clamp_gemm_k_batch(k_batch_);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,

View File

@@ -669,6 +669,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
{
k_batch_ = split_k;
}
k_batch_ = clamp_gemm_k_batch(k_batch_);
// Create descriptors first (with hack flags temporarily set to false)
// so we can check if element space sizes are divisible by k_batch

View File

@@ -408,10 +408,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
? 4 / sizeof(BDataType)
: BBlockTransferSrcScalarPerVector;
static constexpr bool ALdsScalarLoadToVgpr =
(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
static constexpr bool BLdsScalarLoadToVgpr =
(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
// Note: Direct load uses the layout to create the proper block and mma tile descriptors
// TODO: Fix and verify RC layout for the non-direct-load path (currently it returns wrong results)
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3<
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
std::conditional_t<DirectLoad,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor>,
std::conditional_t<DirectLoad,
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor>,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
@@ -456,7 +467,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
DirectLoad>;
DirectLoad,
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
@@ -625,6 +638,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
{
k_batch_ = split_k;
}
k_batch_ = clamp_gemm_k_batch(k_batch_);
// Create descriptors first (with hack flags temporarily set to false)
// so we can check if element space sizes match product of dimensions

View File

@@ -162,6 +162,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
id_off += grid_size_grp;
id_local += grid_size_grp;
block_sync_lds();
}
}
#else

View File

@@ -136,6 +136,7 @@ __launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU)
id_off += grid_size_grp;
id_local += grid_size_grp;
block_sync_lds();
}
}
#else

View File

@@ -13,6 +13,13 @@ namespace ck {
namespace tensor_operation {
namespace device {
/// Ensures GemmKBatch in conv to GEMM transforms is never 0 (would zero the divisor in
/// integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch)).
inline constexpr index_t clamp_gemm_k_batch(index_t k_batch) noexcept
{
return k_batch < 1 ? index_t{1} : k_batch;
}
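A minimal worked example of what the clamp protects against, using the divisor expression named in the comment (the GemmKTotal, GemmK1Number and K0PerBlock values below are illustrative, not taken from any instance):

#include <cstdio>

int main()
{
    const int GemmKTotal = 4096, GemmK1Number = 8, K0PerBlock = 4;
    int k_batch = 0;                                   // caller passed split_k = 0
    k_batch     = k_batch < 1 ? 1 : k_batch;           // clamp_gemm_k_batch
    const int divisor = GemmK1Number * K0PerBlock * k_batch;  // 32 instead of 0
    const int gemm_k0 = (GemmKTotal + divisor - 1) / divisor; // integer_divide_ceil -> 128
    std::printf("divisor=%d GemmK0=%d\n", divisor, gemm_k0);
}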
struct DeviceProperties
{
DeviceProperties()
@@ -33,6 +40,10 @@ inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index
const int max_capacity = max_occupancy * device_properties.num_cu_;
ck::index_t k_batch = 1;
if(grid_size <= 0)
{
return k_batch;
}
const auto optimal_split =
static_cast<ck::index_t>(std::floor((1.0 * max_capacity) / grid_size));
if(optimal_split > 1)

View File

@@ -66,7 +66,9 @@ template <typename ALayout,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
bool DirectLoad = false>
bool DirectLoad = false,
bool ALdsScalarLoadToVgpr = false,
bool BLdsScalarLoadToVgpr = false>
struct GridwiseGemm_xdl_cshuffle_conv_v3
: public GridwiseGemm_xdl_cshuffle_base<
ALayout,
@@ -249,19 +251,90 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
return math::integer_divide_ceil(N, NPerBlock);
}
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
template <typename GridDesc_K0_MN_K1_T, index_t K0Number, index_t K1Value>
__host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc)
{
if constexpr(!DirectLoad)
{
return desc;
}
else
{
const index_t K = desc.GetLength(I0) * desc.GetLength(I2);
const index_t MN = desc.GetLength(I1);
const auto desc_unmerged = transform_tensor_descriptor(
desc,
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, K0Number)),
make_pass_through_transform(MN),
make_pass_through_transform(K1Value)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
const auto desc_permuted = transform_tensor_descriptor(
desc_unmerged,
make_tuple(make_pass_through_transform(K / KPerBlock),
make_xor_with_modulo_transform(make_tuple(MN, K0Number)),
make_pass_through_transform(K1Value)),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
return transform_tensor_descriptor(
desc_permuted,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, K0Number)),
make_pass_through_transform(MN),
make_pass_through_transform(K1Value)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
}
template <index_t MNXdlPerWave,
index_t MNWaves,
index_t MNPerXdl,
bool IsKContinous,
typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
if constexpr(DirectLoad && IsKContinous)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(make_tuple(
Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
constexpr auto desc = transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_xor_with_modulo_transform(make_tuple(Number<MN>{}, Number<K0>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
return transform_tensor_descriptor(
desc,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
else
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
}
template <typename ABlockDesc_AK0_M_AK1>
@@ -270,7 +343,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
return MakeGemmMmaTileDescriptor<MXdlPerWave,
MWaves,
MPerXdl,
is_same<tensor_layout::gemm::RowMajor, ALayout>::value>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
@@ -279,7 +356,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
return MakeGemmMmaTileDescriptor<NXdlPerWave,
NWaves,
NPerXdl,
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value>(
BBlockDesc_BK0_N_BK1{});
}
struct Problem
@@ -366,9 +447,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
{
if constexpr(DirectLoad)
{
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(Number<MPerBlock * AK1Number>{}, I1, Number<MPerBlock>{}));
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(Number<MPerBlock * AK1Number>{}, I1, Number<MPerBlock>{}));
}
}
else if constexpr(is_same_v<DeviceArch, gfx950_t>)
{
@@ -389,9 +479,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
{
if constexpr(DirectLoad)
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
}
}
else if constexpr(is_same_v<DeviceArch, gfx950_t>)
{
@@ -410,34 +509,35 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// Disable vector load from lds to vgpr for direct load (backward weight store with continuous M
// or N dimension)
static constexpr bool LdsScalarLoadToVgpr = DirectLoad;
using BlockwiseGemmPipe = remove_cvref_t<
decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
// static constexpr bool LdsScalarLoadToVgpr = DirectLoad;
using BlockwiseGemmPipe = remove_cvref_t<
decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))),
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))),
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
DirectLoad,
LdsScalarLoadToVgpr>())>;
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
DirectLoad,
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>())>;
template <typename DeviceArch>
__device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch)
@@ -517,8 +617,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id = 0,
const index_t k_batch = 1)
const index_t k_id = 0,
const index_t k_batch = 1,
const index_t block_idx_x = static_cast<index_t>(blockIdx.x))
{
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
@@ -535,8 +636,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
make_multi_index(static_cast<index_t>(blockIdx.x)));
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx_x));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
@@ -570,23 +671,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
auto get_a_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
is_same<tensor_layout::gemm::RowMajor, ALayout>::value ? 2 : 1,
ABlockTransferSrcScalarPerVector >
(a_grid_desc_ak0_m_ak1,
make_multi_index(
SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
}
else
{
@@ -626,23 +723,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
auto get_b_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value ? 2 : 1,
BBlockTransferSrcScalarPerVector >
(b_grid_desc_bk0_n_bk1,
make_multi_index(
SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
}
else
{
@@ -750,8 +843,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id = 0,
const index_t k_batch = 1)
const index_t k_id = 0,
const index_t k_batch = 1,
const index_t block_idx_x = static_cast<index_t>(blockIdx.x))
{
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
@@ -771,7 +865,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
make_multi_index(static_cast<index_t>(blockIdx.x)));
make_multi_index(static_cast<index_t>(block_idx_x)));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
@@ -805,23 +899,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
auto get_a_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
is_same<tensor_layout::gemm::RowMajor, ALayout>::value ? 2 : 1,
ABlockTransferSrcScalarPerVector >
(a_grid_desc_ak0_m_ak1,
make_multi_index(
SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
}
else
{
@@ -861,23 +951,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
auto get_b_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value ? 2 : 1,
BBlockTransferSrcScalarPerVector >
(b_grid_desc_bk0_n_bk1,
make_multi_index(
SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
}
else
{

View File

@@ -21,6 +21,10 @@ template <index_t NDimSpatial,
device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
struct TransformConvBwdWeightToGemm
{
// Same contract as TransformConvBwdWeightToGemmV2 (non-zero K tile factors).
static_assert(GemmK1Number > 0, "GemmK1Number must be positive");
static_assert(K0PerBlock > 0, "K0PerBlock must be positive");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

View File

@@ -31,6 +31,11 @@ template <index_t NDimSpatial,
device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
struct TransformConvBwdWeightToGemmV2
{
// Compile-time contract: divisor GemmK1Number * K0PerBlock * GemmKBatch in
// integer_divide_ceil(GemmKTotal, ...) must stay non-zero (GemmKBatch clamped at runtime).
static_assert(GemmK1Number > 0, "GemmK1Number must be positive");
static_assert(K0PerBlock > 0, "K0PerBlock must be positive");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

View File

@@ -1319,6 +1319,87 @@ CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// Flat async load from global memory to LDS using 64-bit global addressing.
// Bypasses the SRD's 32-bit offset limit; required when the KV cache exceeds
// INT32_MAX (2GB) byte offset on the SRD voffset path.
//
// !!! M0 PRECONDITION — IMPLICIT INPUT NOT VISIBLE IN OPERAND LIST !!!
//
// The LDS destination address is taken from M0 (per AMD CDNA3 ISA §10.3:
// `LDS_ADDR = LDSbase + LDSoffset(M0[17:2] * 4) + INST.OFFSET + ThreadID*4`).
// M0 does NOT appear as an operand of these instructions or of the inline
// asm below — the compiler cannot see the dependency. Caller must:
//
// 1. Initialize M0 once before the load loop:
// `m0_set_with_memory(amd_wave_read_first_lane(lds_byte_offset));`
// M0 is SALU-only — `m0_set_with_memory` uses an "s" constraint to
// enforce this. Direct VALU writes to M0 are illegal.
//
// 2. Advance M0 between successive issues:
// `m0_inc_with_memory(size_per_issue);`
// `size_per_issue` MUST be a multiple of 4 — GLOBAL/FLAT LDS path
// only honors M0[17:2]*4 (dword-aligned), so low 2 bits are silently
// dropped (NOTE: this differs from MUBUF buffer_load_lds which uses
// M0[15:0] as a raw byte offset).
//
// 3. Never bundle `m0_inc_with_memory` and the next call to this
// function into a single inline asm. The compiler auto-inserts a
// hazard NOP between an SALU write to M0 and the consuming
// `global_load_lds_*`; bundling bypasses that and may read stale M0.
//
// The "memory" clobber on this asm is load-bearing: it prevents the
// compiler from reordering this load across other M0-touching helpers
// (`m0_set_with_memory` / `m0_inc_with_memory`, also "memory"-clobbered).
//
// Verified instruction emission (HIP 6.4 / clang 19, gfx942 + gfx950):
// `global_load_lds_dwordx4` is a single instruction (encoding 0xDDF48000
// 0x007F0000), NOT software-expanded into 4× dword. Same encoding on both
// arches. The opcode is undocumented in CDNA3 ISA spec §13.6.2 but
// supported by the LLVM AMDGPU backend.
//
// Available on gfx940+ (CDNA3: MI300 series; CDNA4/gfx950: MI350/MI355 series).
template <unsigned num_dwords, bool pre_nop = false>
CK_TILE_DEVICE void
async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant<pre_nop> = {})
{
#if !defined(__gfx94__) && !defined(__gfx950__)
static_assert(always_false_v<integral_constant<unsigned, num_dwords>>,
"global_load_lds requires CDNA3+ (gfx940/gfx950). "
"Ensure kKVLoadMode is BUFFER_LOAD on this architecture.");
#endif
static_assert(num_dwords == 1 || num_dwords == 4,
"global_load_lds supports num_dwords == 1 or 4 only "
"(2 dwords does not exist on any supported arch; "
"3 dwords only on CDNA4 and unused in FMHA pipeline)");
// Inline asm: only the global address is an explicit operand. The LDS
// destination is implicit via M0 (see contract above). `"=r"(smem)` is an
// SSA scheduling anchor only — `smem` is NOT written by this asm; the
// load goes to LDS at `M0[17:2]*4 + offset:0 + ThreadID*4`.
#define CK_TILE_GLOBAL_LOAD_LDS_INSTR(instr) \
if constexpr(pre_nop) \
asm volatile("s_nop 4\n" instr " %1, off offset:0" \
: "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \
: "v"(global_addr) \
: "memory" /*prevents reorder across m0_{set,inc}*/); \
else \
asm volatile(instr " %1, off offset:0" \
: "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \
: "v"(global_addr) \
: "memory" /*prevents reorder across m0_{set,inc}*/);
if constexpr(num_dwords == 1)
{
CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dword");
}
else if constexpr(num_dwords == 4)
{
CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dwordx4");
}
#undef CK_TILE_GLOBAL_LOAD_LDS_INSTR
}
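To make the three-step M0 contract above concrete, here is a hedged sketch of a conforming caller. The loop shape, the function name, and the per-issue address stepping are assumptions for illustration; m0_set_with_memory, m0_inc_with_memory, amd_wave_read_first_lane and async_buffer_load_fence are the helpers named in the contract and in this file:

// Sketch only: assumes `smem` points into LDS, `gmem` into the KV pool, and
// that each issue moves 16 bytes (one dwordx4), so the M0 step is a multiple of 4.
template <int NumIssues>
CK_TILE_DEVICE void load_tile_to_lds_sketch(void* smem, const void* gmem, index_t lds_byte_offset)
{
    constexpr index_t size_per_issue = 16; // bytes per global_load_lds_dwordx4 issue
    // Step 1: initialize M0 once, wave-uniform (M0 is SALU-only).
    m0_set_with_memory(amd_wave_read_first_lane(lds_byte_offset));
    for(int i = 0; i < NumIssues; ++i)
    {
        // The LDS destination advances via M0, not via `smem` (scheduling anchor only).
        // Load issue and M0 increment stay in separate statements (step 3).
        async_global_load_lds_dwordxn<4>(smem,
                                         static_cast<const char*>(gmem) + i * size_per_issue);
        // Step 2: advance M0 by a multiple of 4 between issues.
        m0_inc_with_memory(size_per_issue);
    }
    async_buffer_load_fence(0); // global_load_lds completion is tracked by vmcnt
}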
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE thread_buffer<int8_t, N>

View File

@@ -45,9 +45,29 @@ template <typename BottomTensorView_,
typename StaticValidArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1,
typename YsGatherDims = sequence<0>>
typename YsGatherDims = sequence<0>,
bool kUseGlobalLoad_ = false>
struct tile_scatter_gather
{
static constexpr bool kUseGlobalLoad = kUseGlobalLoad_;
#if !defined(__gfx94__) && !defined(__gfx950__)
// global_load_lds instruction is only available on CDNA3+ (gfx940/gfx950).
// On other architectures, kUseGlobalLoad must be false.
static_assert(!kUseGlobalLoad_,
"kUseGlobalLoad requires global_load_lds (CDNA3+: gfx940/gfx950). "
"This kernel should not be instantiated on this architecture.");
#endif
// Empty placeholder used by the SRD instantiation so physical_pages_ and
// page_stride_elements_ occupy zero bytes there (combined with
// [[no_unique_address]] on the member declarations). Access sites are all
// inside `if constexpr(kUseGlobalLoad_)` arms, which compile out in SRD
// mode, so no caller needs to change.
struct gl_field_empty_t
{
};
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
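The zero-byte placeholder trick described above is standard C++20; a standalone illustration, independent of ck_tile (names here are made up), shows why the SRD instantiation pays nothing for the global-load-only fields:

#include <array>
#include <cstdio>
#include <type_traits>

struct empty_t
{
};

template <bool UseGlobalLoad>
struct window_sketch
{
    std::array<int, 8> page_idx_{}; // always present
    // Present only in global-load mode; collapses to an empty type otherwise.
    [[no_unique_address]] std::conditional_t<UseGlobalLoad, std::array<int, 8>, empty_t>
        physical_pages_{};
    [[no_unique_address]] std::conditional_t<UseGlobalLoad, int, empty_t> page_stride_elements_{};
};

int main()
{
    // On GCC/Clang the false instantiation is typically the size of page_idx_ alone.
    std::printf("global-load: %zu bytes, SRD: %zu bytes\n",
                sizeof(window_sketch<true>), sizeof(window_sketch<false>));
}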
@@ -233,15 +253,22 @@ struct tile_scatter_gather
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution,
const PageIdxArray& page_idx,
const ValidArray& valids)
const ValidArray& valids,
index_t page_stride_elements = 0)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
page_idx_{page_idx},
physical_pages_{},
page_stride_elements_{},
valids_{valids},
pre_computed_coords_{}
{
if constexpr(kUseGlobalLoad_)
{
page_stride_elements_ = page_stride_elements;
}
#if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
@@ -357,6 +384,34 @@ struct tile_scatter_gather
bottom_tensor_view_.buf_.p_data_ = data;
}
// Override buffer size (input in RAW elements, NOT pre-divided by PackedSize) for
// SRD num_records control. Use to set max range when SRD is rebased per-tile
// (page_size >= kN0 path): each rebased SRD only needs to cover one page; without
// this the SRD claims validity for memory beyond the allocated buffer, which can
// fault on gfx950 page-table validation.
//
// Matches buffer_view ctor convention (buffer_view.hpp:245): input is raw element
// count and is divided by PackedSize before being stored. For PackedSize=1
// (fp16/bf16/fp8) the division is a no-op; for PackedSize=2 (FP4 / packed int4)
// skipping it would over-report num_records by 2x and silently mask OOB on SRD
// reads. batch_prefill currently does not exercise the packed-type path, but this
// setter is generic infrastructure (lives in tile_scatter_gather.hpp) so it must
// honor the same invariant the ctor enforces.
CK_TILE_DEVICE constexpr void set_bottom_tensor_view_buffer_size(index_t size)
{
// Hint the optimizer that size is positive without inserting a runtime
// branch. Using <cassert> assert() here corrupted gfx950 batch_prefill
// output: the __assert_fail handler's SGPR pressure forced the K-SRD
// register window to be reused as scratch and scattered the SRD writes
// across two conditional branches, which gfx950's packed
// buffer_load_dwordx4 issue window doesn't tolerate (gfx942 absorbs it
// via per-tile single-dword loads). __builtin_assume is hint-only —
// no branch, no scratch SGPRs, no codegen impact.
__builtin_assume(size > 0);
using BufType = remove_cvref_t<decltype(bottom_tensor_view_.buf_)>;
bottom_tensor_view_.buf_.buffer_size_ = size / BufType::PackedSize;
}
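A worked example of the PackedSize convention (sizes are illustrative; k_window stands for any tile_scatter_gather instance): for a packed-int4 view with PackedSize == 2 holding 4096 raw elements,

k_window.set_bottom_tensor_view_buffer_size(4096); // raw elements, NOT records

stores 4096 / 2 == 2048 records, matching the buffer_view constructor. Passing a pre-divided 2048 would halve the usable range; skipping the division inside the setter would store 4096, over-reporting the range by 2x and masking OOB on SRD reads.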
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
@@ -458,7 +513,21 @@ struct tile_scatter_gather
// read from bottom tensor
const vector_t vec_value = [&]() {
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
if constexpr(kUseGlobalLoad_)
{
// Global load mode: 64-bit typed pointer arithmetic
const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_;
const auto physical_page = physical_pages_[idx_gather];
const auto coord_offset = bottom_tensor_thread_coord.get_offset();
const long_index_t total_offset =
static_cast<long_index_t>(physical_page) * page_stride_elements_ +
coord_offset + page_offset;
const auto* addr = base_ptr + total_offset;
vector_t v;
__builtin_memcpy(&v, addr, sizeof(vector_t));
return v;
}
else if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
@@ -680,7 +749,23 @@ struct tile_scatter_gather
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
if constexpr(kUseGlobalLoad_)
{
// Global load mode: global_load_lds with 64-bit address
constexpr index_t vector_size =
sizeof(vector_t) / sizeof(uint32_t); // dwords per vector
const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_;
const auto physical_page = physical_pages_[idx_gather];
const auto coord_offset = bottom_tensor_thread_coord.get_offset();
const long_index_t total_offset =
static_cast<long_index_t>(physical_page) * page_stride_elements_ +
coord_offset + page_offset;
const auto* addr = base_ptr + total_offset;
// global_load_lds takes a byte address; addr (const DataType*)
// converts implicitly to const void*, no explicit cast needed.
async_global_load_lds_dwordxn<vector_size>(smem, addr, pre_nop_);
}
else if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
@@ -1046,6 +1131,13 @@ struct tile_scatter_gather
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
CK_TILE_DEVICE void update_physical_pages(const PageIdxArray& pages)
{
static_assert(kUseGlobalLoad_,
"global-load mode only; physical_pages_ is unused in SRD mode.");
physical_pages_ = pages;
}
CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
{
if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
@@ -1139,7 +1231,29 @@ struct tile_scatter_gather
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
// Scatter/gather offsets for each element, set by update_page_idx().
// SRD mode (kUseGlobalLoad=false): buffer_load(SRD, page_idx_[i] + coord).
// page_idx_[i] = within-page offset when kPageBlockSize >= kN0 (SRD rebased to page base)
// page_idx_[i] = page_base + within-page offset when kPageBlockSize < kN0 (full voffset)
// Global load mode (kUseGlobalLoad=true): page_idx_[i] = within-page offset only.
// Full address = base + physical_pages_[i] * page_stride_elements_ + page_idx_[i] + coord
PageIdxArray page_idx_;
// Physical page indices for global load mode (kUseGlobalLoad=true only).
// Maps each gather element to its physical page in a paged memory pool.
// Updated via update_physical_pages() before each load call.
// SRD mode: collapsed to gl_field_empty_t so the storage disappears.
[[no_unique_address]] std::conditional_t<kUseGlobalLoad_, PageIdxArray, gl_field_empty_t>
physical_pages_;
// Page stride in elements for global load mode (kUseGlobalLoad=true only).
// physical_pages_[i] * page_stride_elements_ gives the page base offset in elements.
// Set at construction time via the make_tile_scatter_gather overload that
// takes bool_constant<kUseGlobalLoad>; immutable thereafter.
// SRD mode: collapsed to gl_field_empty_t so the storage disappears.
[[no_unique_address]] std::conditional_t<kUseGlobalLoad_, index_t, gl_field_empty_t>
page_stride_elements_;
ValidArray valids_;
// this contains:
@@ -1178,7 +1292,8 @@ template <typename TensorView_,
typename StaticPageIndexArray_,
index_t HsGatherDim,
index_t NumCoord,
index_t... YsGatherDims>
index_t... YsGatherDims,
bool UseGlobalLoad = false>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
@@ -1187,7 +1302,9 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
const StaticPageIndexArray_& page_idx,
number<HsGatherDim>,
number<NumCoord>,
sequence<YsGatherDims...>)
sequence<YsGatherDims...>,
bool_constant<UseGlobalLoad> = {},
index_t page_stride_elements = 0)
{
return tile_scatter_gather<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
@@ -1196,11 +1313,17 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
std::nullptr_t,
HsGatherDim,
NumCoord,
sequence<YsGatherDims...>>{
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
sequence<YsGatherDims...>,
UseGlobalLoad>{tensor_view,
window_lengths,
origin,
tile_distribution,
page_idx,
nullptr,
page_stride_elements};
}
// Legacy overload (compatible with original API)
// Legacy overload (compatible with original API, kUseGlobalLoad=false)
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
@@ -1227,6 +1350,42 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
}
// Overload with kUseGlobalLoad (simple, used by K cache).
// page_stride_elements is forwarded to the constructor; required (non-zero)
// when UseGlobalLoad=true so that physical_pages_[i] * page_stride_elements_
// produces a valid address. Defaulting to 0 keeps SRD-mode call sites unchanged
// (page_stride_elements_ is unread in SRD mode).
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
bool UseGlobalLoad>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx,
bool_constant<UseGlobalLoad>,
index_t page_stride_elements = 0)
{
return tile_scatter_gather<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<StaticPageIndexArray_>,
std::nullptr_t,
0,
1,
sequence<0>,
UseGlobalLoad>{tensor_view,
window_lengths,
origin,
tile_distribution,
page_idx,
nullptr,
page_stride_elements};
}
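// Minimal usage sketch of the kUseGlobalLoad overload above (window, distribution and
// page-index objects are hypothetical placeholders; the call shape mirrors the K-cache
// call site in BlockFmhaBatchPrefillPipelineQRKSVSAsync):
//
//   auto win = make_tile_scatter_gather(k_view,                 // tensor view
//                                       k_window_lengths,       // window lengths
//                                       {0, 0},                 // window origin
//                                       k_dist,                 // tile distribution
//                                       k_offsets,              // within-page offsets
//                                       bool_constant<true>{},  // kUseGlobalLoad
//                                       page_stride_k);         // elements per page
//   win.update_physical_pages(k_physical_pages); // required in this mode before each load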
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,

View File

@@ -12,6 +12,20 @@
namespace ck_tile {
// `always_false_v<T...>` — a value-template that is always `false` but whose
// evaluation is deferred until template instantiation. The canonical use is
// inside the `else` arm of an `if constexpr` chain or under an arch-gated
// `#if` to fire a `static_assert` ONLY when the offending instantiation is
// actually requested, e.g.:
//
// if constexpr (...) { ... }
// else { static_assert(always_false_v<T>, "unsupported T"); }
//
// A bare `static_assert(false, ...)` would fire at template-definition
// parse time on conforming compilers, breaking the whole TU.
template <typename...>
inline constexpr bool always_false_v = false;
// remove_cvref_t
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include <iostream>
namespace ck_tile {

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
@@ -55,6 +56,7 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"

View File

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// KV cache load addressing mode selector for batch_prefill / paged-attention pipelines.
// - BUFFER_LOAD: SGPR-based SRD via buffer_load_* (default; 32-bit byte addressing, <2GB pool)
// - GLOBAL_LOAD_LDS: direct global_load_lds_* (64-bit addressing, required for >2GB KV cache)
enum class BlockAttentionKVCacheLoadModeEnum
{
BUFFER_LOAD = 0,
GLOBAL_LOAD_LDS = 1,
};
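// Host-side selection sketch (not part of this header; a plausible heuristic only):
// choose GLOBAL_LOAD_LDS when the KV pool can exceed the 32-bit byte-offset range
// of buffer_load, otherwise stay on the SRD fast path.
//
//   inline auto pick_kv_load_mode(std::size_t total_kv_bytes)
//   {
//       return total_kv_bytes > static_cast<std::size_t>(INT32_MAX)
//                  ? BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS
//                  : BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD;
//   }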
} // namespace ck_tile

View File

@@ -32,6 +32,83 @@
namespace ck_tile {
namespace detail {
// A helper struct for detecting n0loop
template <typename T, typename = void>
struct has_n0loop_flag : std::false_type
{
};
template <typename T>
struct has_n0loop_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kUseN0Loop), bool> && T::kUseN0Loop>>
: std::true_type
{
};
template <typename T>
static inline constexpr bool is_n0loop_pipeline_v = has_n0loop_flag<T>::value;
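// Usage sketch of this detection idiom (hypothetical pipeline types, for illustration only):
//
//   struct WithN0Loop    { static constexpr bool kUseN0Loop = true; };
//   struct WithoutN0Loop { /* no kUseN0Loop member */ };
//   static_assert(is_n0loop_pipeline_v<WithN0Loop>);
//   static_assert(!is_n0loop_pipeline_v<WithoutN0Loop>);
//
// The same pattern is reused below for kIgnoreFastExp2, kIsNaiveHDimLoad and kUseTrLoad:
// a pipeline opts in simply by declaring the flag as a true compile-time constant.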
// A helper struct for detecting ignore_fast_exp2 flag
template <typename T, typename = void>
struct has_ignore_fast_exp2_flag : std::false_type
{
};
// kIgnoreFastExp2 is set by pipelines that explicitly choose not to use FAST_EXP2.
// By detecting kIgnoreFastExp2 on the pipeline, the kernel's MakeKargsImpl() interface
// avoids passing an incorrect scale_s parameter down to the kernel layer
template <typename T>
struct has_ignore_fast_exp2_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kIgnoreFastExp2), bool> &&
T::kIgnoreFastExp2>> : std::true_type
{
};
template <typename T>
static inline constexpr bool ignore_fast_exp2_v = has_ignore_fast_exp2_flag<T>::value;
// A helper struct for detecting naive_hdim_load. naive_hdim_load means loading tiles of
// hdim96/hdim160/hdim192 without padding the tensor_view/tile_window to hdim128/hdim256;
// it is currently supported only by the qr_ks_vs_whole_k_prefetch pipeline
template <typename T, typename = void>
struct has_naive_hdim_load_flag : std::false_type
{
};
template <typename T>
struct has_naive_hdim_load_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kIsNaiveHDimLoad), bool> &&
T::kIsNaiveHDimLoad>> : std::true_type
{
};
template <typename T>
static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag<T>::value;
// A helper struct for detecting kUseTrLoad
template <typename T, typename = void>
struct has_use_trload_flag : std::false_type
{
};
template <typename T>
struct has_use_trload_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kUseTrLoad), bool> && T::kUseTrLoad>>
: std::true_type
{
};
template <typename T>
static inline constexpr bool is_using_trload_v = has_use_trload_flag<T>::value;
} // namespace detail
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernel
{
@@ -77,13 +154,14 @@ struct FmhaFwdKernel
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
static constexpr bool kUseTrLoad = detail::is_using_trload_v<FmhaPipeline>;
static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
#if defined(__gfx950__)
static constexpr bool kIsAvailable = true;
#else
static constexpr bool kIsAvailable = !kUseTrLoad;
#endif
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a template
@@ -444,7 +522,9 @@ struct FmhaFwdKernel
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
detail::ignore_fast_exp2_v<FmhaPipeline>
? scale_s
: static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
scale_s,
#endif
@@ -897,7 +977,9 @@ struct FmhaFwdKernel
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
detail::ignore_fast_exp2_v<FmhaPipeline>
? scale_s
: static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
scale_s,
#endif
@@ -1039,6 +1121,7 @@ struct FmhaFwdKernel
const void* seqlen_k_ptr,
const void* block_scale_seqstart_q_ptr,
const void* block_scale_seqstart_k_ptr,
const void* seqstart_v_scale_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -1097,6 +1180,7 @@ struct FmhaFwdKernel
seqlen_k_ptr,
block_scale_seqstart_q_ptr,
block_scale_seqstart_k_ptr,
seqstart_v_scale_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -1158,6 +1242,7 @@ struct FmhaFwdKernel
const void* seqlen_k_ptr,
const void* block_scale_seqstart_q_ptr,
const void* block_scale_seqstart_k_ptr,
const void* seqstart_v_scale_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -1216,6 +1301,7 @@ struct FmhaFwdKernel
seqlen_k_ptr,
block_scale_seqstart_q_ptr,
block_scale_seqstart_k_ptr,
seqstart_v_scale_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -1602,6 +1688,10 @@ struct FmhaFwdKernel
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
constexpr index_t kQKHeaddimToUse = detail::is_naive_hdim_load_v<FmhaPipeline>
? FmhaPipeline::kQKHeaddim
: FmhaPipeline::kSubQKHeaddim;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1612,10 +1702,10 @@ struct FmhaFwdKernel
number<1>{});
if constexpr(FmhaPipeline::kQLoadOnce)
{
return pad_tensor_view(q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
@@ -1634,10 +1724,21 @@ struct FmhaFwdKernel
number<1>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -1649,18 +1750,29 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_transposed = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
if constexpr(!kUseTrLoad)
{
const auto v_dram_transposed = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
else
{
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadHeadDimV>{});
};
}
else
{
@@ -1683,17 +1795,28 @@ struct FmhaFwdKernel
q_dram,
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kSubQKHeaddim>{});
return make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(),
{i_m0, 0});
auto k_dram_window = make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{0, 0});
auto k_dram_window = [&]() {
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
{
return make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
{0, 0});
}
else
{
return make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{0, 0});
}
}();
auto v_dram_window = make_tile_window(
v_dram,
@@ -1843,7 +1966,10 @@ struct FmhaFwdKernel
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
if constexpr(!detail::ignore_fast_exp2_v<FmhaPipeline>)
{
slope *= ck_tile::log2e_v<>;
}
#endif
if constexpr(kHasMask)
{
@@ -2826,7 +2952,10 @@ struct FmhaFwdKernel
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
if constexpr(!detail::ignore_fast_exp2_v<FmhaPipeline>)
{
slope *= ck_tile::log2e_v<>;
}
#endif
if constexpr(kHasMask)
{

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
@@ -134,7 +135,8 @@ template <typename IndexArrayType,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
bool kIsKcache,
index_t kN0,
index_t kVectorSize>
index_t kVectorSize,
bool kUseGlobalLoad_ = false>
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages,
const index_t& stride_token,
const index_t& stride_page_block,
@@ -156,81 +158,65 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica
const index_t& thread_coord_start = coord_vec[kCoordAxis];
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
if constexpr(kIsKcache)
{
// K cache: per-token lookup
// Each token may be on a different page, so we use physical_pages[k0] for each.
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
// Addressing strategy — four cases controlled by (kPageBlockSize vs kN0, kUseGlobalLoad_):
//
// Case 1: kPageBlockSize >= kN0
// SRD is rebased per-tile to the page base (rebase_{k,v}_window in caller).
// Page base is absorbed into the SRD's 48-bit base pointer (SGPR-resident).
// This function writes within-page offset only.
//
// Case 2: kPageBlockSize < kN0 && kUseGlobalLoad_
// SRD cannot be rebased (multi-page wave). Loads use global_load_lds_*; the full
// 64-bit address is computed by tile_scatter_gather::load() in
// include/ck_tile/core/tensor/tile_scatter_gather.hpp from physical_pages_ +
// page_stride_elements_. This function writes within-page offset only.
//
// Case 3: kPageBlockSize < kN0 && !kUseGlobalLoad_ (kNeedFullOffset == true)
// SRD base is the entire KV buffer; the only place to encode page identity
// is the voffset itself. This function writes the FULL offset:
// page * stride_page_block + within_page
// Limited to <2GB total KV bytes by 32-bit voffset hardware width.
//
// Case 4: kPageBlockSize >= kN0 && kUseGlobalLoad_
// Not emitted by codegen. Backstop static_assert in
// BlockFmhaBatchPrefillPipelineQRKSVSAsync.
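// Worked example (illustrative numbers only): with kPageBlockSize = 16 (kLog2PageSize = 4)
// and global_token_idx = 37, token_idx_in_page = 37 & 15 = 5. Case 3 then writes
// physical_pages[k0] * stride_page_block + 5 * stride_token into kv_offset_vec; Cases 1 and 2
// write only 5 * stride_token (or its VECTORIZED_LAYOUT equivalent).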
constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseGlobalLoad_;
if constexpr(kPageBlockSize >= kN0)
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
// Within-page offset (layout-dependent for V cache with VECTORIZED_LAYOUT)
const index_t within_page = [&]() {
if constexpr(!kIsKcache && kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// SRD rebasing mode: within-page offset only.
// The full page base is handled by rebasing the SRD pointer.
kv_offset_vec[k0] = token_idx_in_page * stride_token;
return (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
}
else
{
// Full global offset (original code path for ps1, ps16, etc.)
const index_t physical_page = physical_pages[k0];
kv_offset_vec[k0] =
physical_page * stride_page_block + token_idx_in_page * stride_token;
return token_idx_in_page * stride_token;
}
});
}
else // V cache
{
// V cache: use physical_pages[k0] for each token
// physical_pages was already populated correctly by load_physical_pages(), handling:
// - page_size=1: page_idx maps token_idx -> physical_page directly
// - V tile crosses pages: per-token page lookup
// - V tile in single page: lane0 lookup with broadcast to all lanes
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
}();
if constexpr(kPageBlockSize >= kN0)
{
// SRD rebasing mode: within-page offset only.
// The full page base is handled by rebasing the SRD pointer.
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const index_t token_offset =
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
kv_offset_vec[k0] = token_offset;
}
else
{
kv_offset_vec[k0] = token_idx_in_page * stride_token;
}
}
else
{
// Full global offset (original code path for ps1, ps16, etc.)
const index_t physical_page = physical_pages[k0];
const long_index_t page_base_offset =
static_cast<long_index_t>(physical_page) * stride_page_block;
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const index_t token_offset =
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
kv_offset_vec[k0] = page_base_offset + token_offset;
}
else
{
kv_offset_vec[k0] = page_base_offset + token_idx_in_page * stride_token;
}
}
});
}
// SRD + page_size < kN0: add page base to form complete voffset for buffer_load.
//
// 32-bit by hardware: SRD buffer_load voffset is fundamentally 32-bit (CDNA3 MUBUF
// microcode format), so this branch is only reachable when total KV bytes fit in
// INT32_MAX. The kUseGlobalLoad_ template path handles the >2GB case via 64-bit
// global_load_lds_*; widening kv_offset_vec here would not lift the 2GB ceiling
// because the hardware truncates voffset regardless.
if constexpr(kNeedFullOffset)
{
kv_offset_vec[k0] = physical_pages[k0] * stride_page_block + within_page;
}
else
{
kv_offset_vec[k0] = within_page;
}
});
}
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
@@ -270,10 +256,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
static constexpr index_t kVectorSize = Problem::kVectorSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
// Single load-mode selector for the whole pipeline. GLOBAL_LOAD_LDS routes K/V
// tiles through global_load_lds_* (handles >2GB KV cache); BUFFER_LOAD uses SRD
// buffer_load_*. The enum is named at the trait/Problem level; internally we
// derive a bool helper to keep `if constexpr` sites narrow. Codegen only emits
// GLOBAL_LOAD_LDS arms when page_size < kN0; the static_assert is a backstop.
static constexpr auto kKVLoadMode = Problem::kKVLoadMode;
static constexpr bool kUseGlobalLoad =
(kKVLoadMode == BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS);
static_assert(!kUseGlobalLoad || (kPageBlockSize < kN0),
"GLOBAL_LOAD_LDS load mode is only valid when kPageBlockSize < kN0; "
"codegen should not emit this instantiation otherwise.");
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
@@ -626,19 +623,26 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
true,
kN0,
kVectorSize>(
kVectorSize,
kUseGlobalLoad>(
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
k_dist,
k_offsets); // K DRAM tile window for
k_offsets,
bool_constant<kUseGlobalLoad>{},
page_stride_k);
if constexpr(kUseGlobalLoad)
{
k_dram_window.update_physical_pages(k_physical_pages);
}
k_dram_window.init_raw();
// SRD rebasing: move the buffer descriptor base pointer to each page's start address
// using 48-bit pointer arithmetic, so voffset only needs the small within-page offset.
// Only applies when kPageBlockSize >= kN0 (all threads in a wave access the same page).
// SRD rebasing for K: only for page_size >= kN0 (all threads on same page).
// For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
// addressing.
auto rebase_k_window = [&](auto& window, index_t physical_page) {
if constexpr(kPageBlockSize >= kN0)
{
@@ -649,24 +653,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const auto* page_ptr =
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_k;
window.set_bottom_tensor_view_data_ptr(page_ptr);
// Limit SRD num_records to one page worth of elements.
// Without this, the SRD claims validity for [page_ptr, page_ptr +
// full_buffer_size), which extends far beyond the allocated buffer when rebased to
// high pages. On gfx950, the hardware may validate the full SRD range against page
// table permissions, causing faults on freed/protected memory beyond the buffer.
window.set_bottom_tensor_view_buffer_size(page_stride_k);
window.init_raw();
}
};
// SRD rebasing for V: only for page_size >= kN0 (all threads on same page).
// For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
// addressing.
auto rebase_v_window = [&](auto& window, index_t physical_page) {
if constexpr(kPageBlockSize >= kN0)
{
// readfirstlane: make physical_page provably wave-uniform so the
// resulting SRD lands in SGPRs (required by buffer load instructions).
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
const auto* base_ptr =
v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_;
const auto* page_ptr =
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_v;
window.set_bottom_tensor_view_data_ptr(page_ptr);
window.set_bottom_tensor_view_buffer_size(page_stride_v);
window.init_raw();
}
};
// Initial K SRD rebase
// Initial K SRD rebase (no-op for page_size < kN0, uses flat loads instead)
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
constexpr auto k_oob_ck = bool_constant<true>{};
@@ -874,12 +890,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
false,
kN0,
kVectorSize>(v_physical_pages_k2,
stride_v,
page_stride_v,
v_coord,
v_offsets_k2,
current_seq_k);
kVectorSize,
kUseGlobalLoad>(v_physical_pages_k2,
stride_v,
page_stride_v,
v_coord,
v_offsets_k2,
current_seq_k);
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
@@ -899,9 +916,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
kVectorSize,
kUseGlobalLoad>(
v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
}
// v_offsets semantics — see the four-case addressing-strategy block above
// kNeedFullOffset in kv_offset_array_transform. Three cases reach this lambda:
// Case 1 (kPageBlockSize >= kN0): within-page offset; page base in SRD.
// Case 2 (page_size < kN0, kUseGlobalLoad): within-page offset; page base computed
// by tile_scatter_gather::load() from
// physical_pages_.
// Case 3 (page_size < kN0, !kUseGlobalLoad, == kNeedFullOffset):
// FULL offset (page * stride + within),
// carried in the 32-bit voffset (<2GB cap).
};
// Prefetch V physical pages early to hide buffer load latency
@@ -915,11 +943,32 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
v_offsets,
number<1>{}, // HsGatherDim
number<1>{}, // NumCoord
VPageIndexYDims);
VPageIndexYDims,
bool_constant<kUseGlobalLoad>{},
page_stride_v);
if constexpr(kUseGlobalLoad)
{
v_dram_window.update_physical_pages(v_physical_pages);
}
// Initial V SRD rebase
// Initial V SRD rebase. Single source of truth: rebase_v_window's own
// `if constexpr(kPageBlockSize >= kN0)` makes this a no-op for case 2/3.
// Do not re-add an outer guard here — it would duplicate the inner check
// and drift if the lambda's gating condition ever changes.
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
// Save the *current* tile's V physical pages into v_dram_window before
// prefetch_v_physical_pages overwrites the v_physical_pages buffer with the
// *next* tile's pages. Case-2 only (kUseGlobalLoad); case-1/3 don't read
// physical_pages_ from the window. Encapsulating the save+prefetch pair
// here makes the ordering invariant unmissable when a fourth prefetch site
// is added later.
auto save_and_prefetch_v_pages = [&](auto k_loop_start) {
if constexpr(kUseGlobalLoad)
v_dram_window.update_physical_pages(v_physical_pages);
prefetch_v_physical_pages(k_loop_start);
};
// prefetch K tile
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
@@ -972,7 +1021,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
}
// Prefetch V physical pages early - overlaps with GEMM0 computation
prefetch_v_physical_pages(number<kK1>{});
save_and_prefetch_v_pages(number<kK1>{});
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
@@ -1166,7 +1215,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// Prefetch V physical pages early - overlaps with softmax computation
if constexpr(k1_loops > 1)
{
prefetch_v_physical_pages(number<2 * kK1>{});
save_and_prefetch_v_pages(number<2 * kK1>{});
}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
@@ -1220,8 +1269,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
v_dram_window,
{0,
kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
update_v_offsets(number<2 * kK1>{});
v_dram_window.update_page_idx(v_offsets);
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
@@ -1390,8 +1438,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
// Update V offsets using previously prefetched physical pages
update_v_offsets(number<(2 + i_k1.value) * kK1>{});
v_dram_window.update_page_idx(v_offsets);
@@ -1401,7 +1448,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// Prefetch V physical pages for NEXT iteration - overlaps with GEMM1
if constexpr(i_k1 + 1 < k1_loops - 1)
{
prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{});
save_and_prefetch_v_pages(number<(2 + i_k1.value + 1) * kK1>{});
}
block_sync_lds();
@@ -1481,9 +1528,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
true,
kN0,
kVectorSize>(
kVectorSize,
kUseGlobalLoad>(
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
k_dram_window.update_page_idx(k_offsets);
if constexpr(kUseGlobalLoad)
k_dram_window.update_physical_pages(k_physical_pages);
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
// After sink→window transition (i_total_loops == num_sink_loop), V window

View File

@@ -9,6 +9,52 @@
namespace ck_tile {
namespace detail {
template <typename DataType, index_t ElemPerThread>
CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize()
{
if constexpr(std::is_same_v<DataType, half_t> || std::is_same_v<DataType, bf16_t>)
{
// ToDo: need support in ck_tile for using buffer_load_dwordx3
// if constexpr(ElemPerThread % 6 == 0)
// return 6;
if constexpr(ElemPerThread % 8 == 0)
return 8;
else if constexpr(ElemPerThread % 4 == 0)
return 4;
else if constexpr(ElemPerThread % 2 == 0)
return 2;
return 1;
}
else if constexpr(std::is_same_v<DataType, float>)
{
// ToDo: need support in ck_tile for using buffer_load_dwordx3
// if constexpr(ElemPerThread % 3 == 0)
// return 3;
if constexpr(ElemPerThread % 4 == 0)
return 4;
else if constexpr(ElemPerThread % 2 == 0)
return 2;
return 1;
}
else
return 1;
};
template <typename DataType,
index_t kThreadBlockSize,
index_t kHigherDimSize,
index_t kLowerDimSize>
CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize()
{
constexpr index_t ElemPerThread = (kHigherDimSize * kLowerDimSize) / kThreadBlockSize;
return GetMaxVectorSize<DataType, ElemPerThread>();
}
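// Worked example (illustrative numbers only): for half_t with kThreadBlockSize = 256,
// kHigherDimSize = 128 and kLowerDimSize = 64, ElemPerThread = (128 * 64) / 256 = 32, so
// GetMaxVectorSize returns 8, i.e. 16 bytes (one dwordx4 buffer_load) per thread access.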
} // namespace detail
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
@@ -117,6 +163,12 @@ struct BlockFmhaBatchPrefillPipelineProblem
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
"kPageBlockSize must be power of two");
// KV cache load addressing mode. GLOBAL_LOAD_LDS handles >2GB pools via
// 64-bit addressing; BUFFER_LOAD (default) uses SRD buffer_load for the
// <2GB fast path. The 2GB bound = INT32_MAX byte offset, matching CK's
// existing TwoGB convention.
static constexpr auto kKVLoadMode = Traits_::kKVLoadMode;
static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
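// e.g. fp16/bf16 gives kVectorSize = 8 and an 8-bit KV type gives 16; either way a single
// access spans 16 bytes, i.e. one dwordx4 load.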
static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
static constexpr auto kKVLookupTable = Traits_::kKVLookupTable;

View File

@@ -0,0 +1,861 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using CompDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true;
static_assert(kQLoadOnce == Policy::QLoadOnce);
static_assert(sizeof(KDataType) == sizeof(VDataType) &&
alignof(KDataType) == alignof(VDataType),
"K and V share the same LDS region; their element types must have identical "
"size and alignment.");
static constexpr bool kUseN0Loop = true;
static constexpr bool kIgnoreFastExp2 = true;
static constexpr bool kIsNaiveHDimLoad = true;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kN0Sub =
BlockFmhaShape::kK0; // subdivision of kN0 used in N0-loop, same value as kK0
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static_assert(Problem::kUseTrLoad == true, "Check failed!");
static constexpr bool kUseTrLoad = true;
// Since this pipeline is only used by the xformers inference path, the dropout function is
// not well tested with it, so dropout is disabled here
static_assert(kHasDropout == false, "Dropout is not supported by this pipeline at present!");
// last-dimension vector length used to create the tensor view (and to decide the buffer_load
// vector length) together with the tensor distribution; the tensor distribution should be able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kQKHeaddim == 32)
{
return 2;
}
else if constexpr(kQKHeaddim == 64)
{
return 2;
}
else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128)
{
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim == 256)
{
return 1;
}
else
{
return 1;
};
}
}();
static constexpr const char* name = "qr_async_whole_k_prefetch_trload";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& /* unused */,
const AttentionVariantParams& /* unused */,
const BlockIndices& /* unused */,
void* smem_ptr,
DropoutType& dropout) const
{
// xformers path does not require the pipeline to output random values for host
// verification, since a separate kernel is used to generate random values
ignore = randval_dram_block_window_tmp;
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0Sub == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr index_t n0_loops = kN0 / kN0Sub;
constexpr index_t k1_loops = kN0 / kK1;
// usually kN0 is 128 and kN0Sub/kK1 are 32/16
static_assert(n0_loops >= 2, "n0_loops >= 2 required to use this pipeline");
static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline");
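// With the typical shape noted above (kN0 = 128, kN0Sub = 32, kK1 = 16) this gives
// n0_loops = 128 / 32 = 4 and k1_loops = 128 / 16 = 8, satisfying both asserts.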
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
static_assert(n0_loops >= NumPrefetchV, "Check failed!");
static_assert(k1_loops >= NumPrefetchV, "Check failed!");
constexpr bool kPreloadWholeNextIterationK =
Policy::template IsPreloadWholeNextIterationK<Problem>();
// This path prefetches two k_tiles for the next iteration, so it has the opportunity to
// prefetch two v_tiles during Gemm0
if constexpr(!kPreloadWholeNextIterationK)
{
static_assert(NumPrefetchV >= 2);
};
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
// SaccBlockTile size is [kM0, kK1]
// PcompBlockTile size is [kM0, kN0]
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
SaccBlockTileType sacc_tile;
PcompBlockTileType pcomp_tile;
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
using MLBlockTileType = decltype(block_tile_reduce<CompDataType>(
PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0}));
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
OaccBlockTileType o_acc;
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
if(seqlen_k_end <= seqlen_k_start)
{
clear_tile(o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
};
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
auto q_tile = load_tile(q_dram_window);
using k_tile_type = decltype(load_tile(k_dram_window));
auto k_tiles = [&]() {
if constexpr(kPreloadWholeNextIterationK)
return statically_indexed_array<k_tile_type, n0_loops>{};
else
return statically_indexed_array<k_tile_type, 2>{};
}();
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
if constexpr(!kPreloadWholeNextIterationK)
{
k_tiles[I1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
};
__builtin_amdgcn_sched_barrier(0x00000001);
// provide partition_index for the LDS tile window so that warp_id stays in a VGPR
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
// K tile in LDS
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
using k_lds_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
using v_lds_window_type =
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kK1, kN1>{}));
statically_indexed_array<v_lds_window_type, NumKVLdsBuffers> v_lds_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
v_lds_windows[i_buf] = get_slice_tile(
v_lds_window, sequence<i_buf * kK1, 0>{}, sequence<(i_buf + 1) * kK1, kN1>{});
});
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kK1>{}, number<kN1>{}),
{seqlen_k_start, 0},
Policy::template MakeVDramTileDistribution<Problem>());
const auto f_exp = [&](CompDataType x) {
if constexpr(std::is_same_v<CompDataType, float>)
{
return __expf(x);
}
else
{
return exp(x);
}
};
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kN0>{}),
{bias_origin.at(number<0>{}), seqlen_k_start},
Policy::template MakeBiasDramTileDistribution<Problem>());
// assuming no random values need to be saved; this is true when the pipeline is called from
// xformers, since a separate kernel is used to generate random values
auto null_randval_window = [&]() {
if constexpr(kHasDropout)
{
// need to pass a null_randval_dram and tile window to the BlockDropout operator to
// make it work
const auto null_randval_dram = [&]() {
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<uint8_t*>(nullptr),
make_tuple(1, 1),
make_tuple(1, 1),
number<1>{},
number<1>{});
return pad_tensor_view(null_dram_naive,
make_tuple(number<1>{}, number<1>{}),
sequence<true, true>{});
}();
return make_tile_window(
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
}
else
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
clear_tile(o_acc);
set_tile(m, -numeric<CompDataType>::infinity());
clear_tile(l);
q_tile = tile_elementwise_in(q_element_func, q_tile);
auto seqlen_k_curr = seqlen_k_start;
using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
do
{
// STAGE 1, Gemm_0 ( S = Q@K )
if constexpr(kPreloadWholeNextIterationK) // used when kM0 = 64
{
if(seqlen_k_curr == seqlen_k_start) // at first iteration
{
if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration
{
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
partition_index);
if constexpr(i_n0 < n0_loops - 1)
{
k_tiles[number<i_n0 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
};
if constexpr(i_n0 == n0_loops - 1)
{
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
// prefetch all k_tiles for next iteration
static_for<0, n0_loops, 1>{}([&](auto ii_n0) {
k_tiles[number<ii_n0>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
});
};
block_sync_lds();
gemm_0(
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
tmp_tile,
sequence<0, i_n0 * kN0Sub>{},
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
});
}
else // the iteration is also the last iteration
{
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
partition_index);
if constexpr(i_n0 < n0_loops - 1)
{
k_tiles[number<i_n0 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
};
if constexpr(i_n0 == n0_loops - 1)
{
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
};
block_sync_lds();
gemm_0(
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
tmp_tile,
sequence<0, i_n0 * kN0Sub>{},
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
});
};
}
else // at intermediate and last iteration
{
if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration
{
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
partition_index);
if constexpr(i_n0 == 0)
{
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
};
// prefetch k_tile for next iteration
k_tiles[i_n0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
block_sync_lds();
gemm_0(
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
tmp_tile,
sequence<0, i_n0 * kN0Sub>{},
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
});
}
else // last iteration
{
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
partition_index);
if constexpr(i_n0 == 0)
{
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
};
block_sync_lds();
gemm_0(
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
tmp_tile,
sequence<0, i_n0 * kN0Sub>{},
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
});
};
}
}
else // only preload one unroll of K for next iteration, used when kM0=128
{
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_n0 % 2>{}]),
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
if constexpr(i_n0 < n0_loops - 2)
{
k_tiles[number<i_n0 % 2>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
};
if constexpr(i_n0 >= n0_loops - 2)
{
v_tiles[number<i_n0 - (n0_loops - 2)>{}] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
};
__builtin_amdgcn_sched_barrier(0x00000001);
block_sync_lds();
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
tmp_tile,
sequence<0, i_n0 * kN0Sub>{},
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
});
}
__builtin_amdgcn_sched_barrier(0x00000001);
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
tile_elementwise_inout(
[&](auto& x, const auto y) {
x += type_convert<CompDataType>(bias_element_func(y));
},
pcomp_tile,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
constexpr auto pcomp_spans = decltype(pcomp_tile)::get_distributed_spans();
sweep_tile_span(pcomp_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(pcomp_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
pcomp_tile(i_j_idx) *= scale_s;
position_encoding.update(pcomp_tile(i_j_idx), row, col);
});
});
}
else
{
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
q_origin.at(number<0>{}), seqlen_k_curr, number<kM0>{}, number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(pcomp_tile, -numeric<CompDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
__builtin_amdgcn_sched_barrier(0x00000001);
auto m_local = block_tile_reduce<CompDataType>(
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m;
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
__builtin_amdgcn_sched_barrier(0);
// check whether the first V LDS buffer overlaps with the last K LDS buffer;
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[I0]),
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
if constexpr(kPreloadWholeNextIterationK)
{
static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) {
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
});
}
else
{
static_for<2, NumPrefetchV, 1>{}([&](auto i_k1) {
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
});
};
__builtin_amdgcn_sched_barrier(0);
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(m[i_idx] == -numeric<CompDataType>::infinity())
{
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
});
}
else
{
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
});
}
});
auto rowsum_p =
block_tile_reduce<CompDataType>(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0});
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
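// Online-softmax bookkeeping (summary of the surrounding code): m already holds
// m_new = max(m_old, rowmax(s)), pcomp_tile holds exp(s - m_new), and rowsum_p is its row sum.
// The sweep below then updates, per row,
//   l_new = exp(m_old - m_new) * l_old + rowsum_p
//   o_acc = exp(m_old - m_new) * o_acc   // the new P@V contribution is accumulated in STAGE 3
// Rows whose running max is still -infinity (fully masked so far) are reset instead of rescaled.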
// adjust o_acc[] according to the update between m and m_old
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(m[i_idx] == -numeric<CompDataType>::infinity())
{
l(i_idx) = rowsum_p[i_idx];
}
else
{
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
}
});
__builtin_amdgcn_sched_barrier(0x00000001);
if constexpr(kHasDropout)
{
auto randval_lds_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
}
seqlen_k_curr += kN0;
__builtin_amdgcn_sched_barrier(0x00000001);
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
__builtin_amdgcn_sched_barrier(0x00000001);
// STAGE 3, Gemm_1 ( O = P@V )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
if constexpr(i_k1 < k1_loops - NumPrefetchV)
{
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
};
if constexpr(i_k1 == k1_loops - NumPrefetchV)
{
if constexpr(!kPreloadWholeNextIterationK)
{
if(seqlen_k_curr < seqlen_k_end)
{
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
};
}
};
if constexpr(i_k1 == k1_loops - NumPrefetchV + 1)
{
if constexpr(!kPreloadWholeNextIterationK)
{
if(seqlen_k_curr < seqlen_k_end)
{
k_tiles[I1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
};
}
};
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
if constexpr(i_k1 < k1_loops - 1)
{
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func,
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]),
partition_index);
};
});
// check whether the last V LDS buffer overlaps with the first K LDS buffer;
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0)
{
__builtin_amdgcn_s_barrier();
};
} while(seqlen_k_curr < seqlen_k_end);
// store LSE: per-row logsumexp, lse = m + log(l) = log(sum_j exp(s_j))
if constexpr(kStoreLSE)
{
auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
});
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if(m[i_idx] == -numeric<CompDataType>::infinity())
o_acc(i_j_idx) = 0.0f;
else
o_acc(i_j_idx) *= 1.0f / l[i_idx];
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout,
const float sink_v) const
{
ignore = sink_v;
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
}
};
} // namespace ck_tile

View File

@@ -692,8 +692,11 @@ struct BlockFmhaPipelineQSKSVS
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout) const
DropoutType& dropout,
const float sink_v) const
{
ignore = sink_v;
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,

View File

@@ -57,7 +57,7 @@ struct TileFmhaShape
static constexpr index_t kQKHeaddim =
BlockTile::at(number<5>{}); // total length of K0, used by pipelines that need to load Q at
// once (or repeatedly load Q as a whole tile)
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim must be divisible by kK0!");
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
@@ -58,7 +59,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
kPadSeqLenK_,
kPadHeadDimQ_,
@@ -76,6 +79,7 @@ struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;
static constexpr index_t kPageBlockSize = kPageBlockSize_;
static constexpr auto kKVLoadMode = kKVLoadMode_;
static_assert(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT ||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT,
"Batch prefill only supports vectorized or linear KV cache layout.");

View File

@@ -1685,7 +1685,7 @@ struct MoeSortingMultiPhaseKernel_P0_v1
IndexType eid = x[j.value]; // ext_vector_type must use int to []
uint32_t curr_token_id, curr_topk_id;
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
if(eid < kargs.num_experts)
if(eid < kargs.num_experts && eid >= 0)
{
if constexpr(Problem::LocalToken)
{

View File

@@ -0,0 +1,268 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
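//
// The schedule below prefetches along K: within each N iteration, the LDS read
// of the (k+1)-th B slice is issued before the warp GEMMs of the k-th slice,
// and __builtin_amdgcn_sched_barrier(0x1) is intended to keep the prefetch
// loads from being rescheduled past the barrier, so ds_read latency overlaps
// the MFMA work.
//
// Minimal usage sketch (the Problem type name is illustrative, not part of
// this header; A lives in registers, B in LDS):
//   using BlockGemm = BlockGemmARegBSmemCRegV2PrefetchK<MyBlockGemmProblem>;
//   auto c_acc = BlockGemm{}(a_reg_tile, b_lds_window); // C = A * B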
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
struct BlockGemmARegBSmemCRegV2PrefetchK
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C = A * B (C is overwritten on the first K iteration, then accumulated over the remaining K iterations)
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id<false>() % NWarp;
static_assert(NWarp == 1, "this block GEMM variant only supports NWarp == 1!");
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// construct A-block-tensor from A-block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
MakeABlockTileDistribution());
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto I0 = number<0>{};
// hot loop:
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));
statically_indexed_array<b_warp_tensor_type, KIterPerWarp> b_warp_tensors;
// read B warp tensor from B Block window
b_warp_windows(nIter)(I0) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(I0),
{nIter * NPerBlockPerIter, 0 * KPerBlockPerIter});
b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0));
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
if constexpr(kIter < KIterPerWarp - 1)
{
// read B warp tensor from B Block window
b_warp_windows(nIter)(number<kIter + 1>{}) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(number<kIter + 1>{}),
{nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter});
b_warp_tensors[number<kIter + 1>{}] =
load_tile(b_warp_windows(nIter)(number<kIter + 1>{}));
};
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
if constexpr(kIter == 0)
{
// warp GEMM
c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[kIter]);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
}
else
{
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
};
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
{
constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode<MPerBlock, KPerBlock>();
return make_static_tile_distribution(a_block_dstr_encode);
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
static_assert(NWarp == 1, "this block GEMM variant only supports NWarp == 1!");
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
return c_block_dstr_encode;
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode<MPerBlock, NPerBlock>();
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
return c_block_tensor;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,239 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
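//
// Same register/LDS layout as the PrefetchK variant, but the prefetch runs
// along N: for each K slice, the (n+1)-th B warp tile is loaded from LDS
// before the warp GEMMs of the n-th tile are issued. C is always
// read-modify-written, so the three-argument operator() genuinely accumulates
// (C += A * B).
//
// Minimal usage sketch (type names are illustrative; zero the accumulator
// first, e.g. with clear_tile):
//   using BlockGemm = BlockGemmARegBSmemCRegV2PrefetchN<MyBlockGemmProblem>;
//   auto c_acc = BlockGemm::MakeCBlockTile();
//   clear_tile(c_acc);
//   BlockGemm{}(c_acc, a_reg_tile, b_lds_window); // C += A * B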
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
struct BlockGemmARegBSmemCRegV2PrefetchN
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id<false>() % NWarp;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// construct A-block-tensor from A-block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
MakeABlockTileDistribution());
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto I0 = number<0>{};
using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
statically_indexed_array<b_warp_tensor_type, NIterPerWarp> b_warp_tensors;
// read B warp tensor from B Block window
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(I0)(kIter),
{0 * NPerBlockPerIter, kIter * KPerBlockPerIter});
b_warp_tensors(I0) = load_tile(b_warp_windows(I0)(kIter));
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter < NIterPerWarp - 1)
{
// read B warp tensor from B Block window
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
{(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter});
b_warp_tensors(number<nIter + 1>{}) =
load_tile(b_warp_windows(number<nIter + 1>{})(kIter));
};
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return make_static_tile_distribution(a_block_dstr_encode);
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
return c_block_tensor;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,243 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
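//
// Prefetch-over-N schedule identical to BlockGemmARegBSmemCRegV2PrefetchN,
// but B is addressed as a K x N window and fetched with load_tile_transpose(),
// so B can stay K-major in shared memory while each warp GEMM still consumes
// an N x K tile. Usage mirrors the non-transposed PrefetchN variant.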
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
struct BlockGemmARegBSmemTrLoadCRegV2PrefetchN
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<1>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id<false>() % NWarp;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// construct A-block-tensor from A-block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
MakeABlockTileDistribution());
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
constexpr auto b_warp_dstr_encode =
typename InputTileDistributionTraits<typename WG::BWarpDstrEncoding,
BDataType>::TransposedDstrEncode{};
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kK>{}, number<WG::kN>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{0, iNWarp * WG::kN},
make_static_tile_distribution(b_warp_dstr_encode));
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto I0 = number<0>{};
using b_warp_tensor_type = decltype(load_tile_transpose(b_warp_windows(I0)(I0)));
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
statically_indexed_array<b_warp_tensor_type, NIterPerWarp> b_warp_tensors;
// read B warp tensor from B Block window
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(I0)(kIter),
{kIter * KPerBlockPerIter, 0 * NPerBlockPerIter});
b_warp_tensors(I0) = load_tile_transpose(b_warp_windows(I0)(kIter));
__builtin_amdgcn_sched_barrier(0);
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter < NIterPerWarp - 1)
{
// read B warp tensor from B Block window
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
{kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter});
b_warp_tensors(number<nIter + 1>{}) =
load_tile_transpose(b_warp_windows(number<nIter + 1>{})(kIter));
};
__builtin_amdgcn_sched_barrier(0);
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return make_static_tile_distribution(a_block_dstr_encode);
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
return c_block_tensor;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,85 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using BF8 = ck::bf8_t;
using F8 = ck::f8_t;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default;
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_v3_f16_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_v3_bf16_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
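// Usage sketch (illustrative, not part of this header): the tuples above are
// meant to be registered with add_device_operation_instances(...) and then
// picked up through the usual CK factory path, roughly:
//
//   using Op = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
//       2, NHWGK, GKYXC, Empty_Tuple, NHWGC, F16, F16, Empty_Tuple, F16,
//       PassThrough, PassThrough, PassThrough>;
//   auto op_ptrs = ck::tensor_operation::device::instance::
//       DeviceOperationInstanceFactory<Op>::GetInstances();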

View File

@@ -108,6 +108,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
op_ptrs);
@@ -148,6 +150,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(

View File

@@ -56,6 +56,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
@@ -232,6 +246,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loa
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16

View File

@@ -32,6 +32,8 @@ add_instance_library(
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp

View File

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_v3_f16_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_v3_f16_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck