diff --git a/Dockerfile.aiter b/Dockerfile.aiter index 8d6e995656..4fcebc9033 100644 --- a/Dockerfile.aiter +++ b/Dockerfile.aiter @@ -10,27 +10,31 @@ RUN pip install pandas zmq einops ninja tabulate vcs_versioning && \ sudo mkdir /home/jenkins/workspace && \ cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \ if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \ - git clone --depth 1 -b "$CK_AITER_BRANCH" --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \ - cd rocm-libraries && \ + mkdir rocm-libraries && cd rocm-libraries && \ + git init -q && \ + git remote add origin https://github.com/ROCm/rocm-libraries.git && \ + git fetch --depth 1 --filter=blob:none origin "$CK_AITER_BRANCH" && \ git sparse-checkout init --cone && \ git sparse-checkout set projects/composablekernel && \ - git checkout "$CK_AITER_BRANCH" && \ + git checkout FETCH_HEAD && \ ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \ + LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}" && \ mv projects/composablekernel ../ck && \ cd ../ck && rm -rf ../rocm-libraries && \ - git init && \ + git init -b "$LOCAL_BRANCH" && \ git config user.name "assistant-librarian[bot]" && \ git config user.email "assistant-librarian[bot]@users.noreply.github.com" && \ - git branch -m "$CK_AITER_BRANCH" && git add -A && \ + git add -A && \ git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" ; \ else \ - git clone --depth 1 -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git ck ; \ + git clone --depth 1 -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git ck && \ + LOCAL_BRANCH="$CK_AITER_BRANCH" ; \ fi && \ cd /home/jenkins/workspace && rm -rf aiter && \ git clone --depth 1 -b "$AITER_BRANCH" --recursive https://github.com/ROCm/aiter.git && \ cd aiter && \ rm -rf 3rdparty/composable_kernel/ && \ - git clone -b "$CK_AITER_BRANCH" ../ck 3rdparty/composable_kernel/ && \ + git clone -b "$LOCAL_BRANCH" ../ck 3rdparty/composable_kernel/ && \ python3 setup.py develop && \ groupadd -g 1001 jenkins && \ useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \ diff --git a/Dockerfile.fa b/Dockerfile.fa index 47643310bd..025bbd414e 100644 --- a/Dockerfile.fa +++ b/Dockerfile.fa @@ -12,27 +12,31 @@ RUN set -x ; \ sudo mkdir /home/jenkins/workspace && \ cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \ if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \ - git clone --depth 1 -b "$CK_FA_BRANCH" --no-checkout --filter=blob:none https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \ - cd rocm-libraries && \ + mkdir rocm-libraries && cd rocm-libraries && \ + git init -q && \ + git remote add origin https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \ + git fetch --depth 1 --filter=blob:none origin "$CK_FA_BRANCH" && \ git sparse-checkout init --cone && \ git sparse-checkout set projects/composablekernel && \ - git checkout "$CK_FA_BRANCH" && \ + git checkout FETCH_HEAD && \ ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \ + LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}" && \ mv projects/composablekernel ../ck && \ cd ../ck && rm -rf ../rocm-libraries && \ - git init && \ + git init -b "$LOCAL_BRANCH" && \ git config user.name "assistant-librarian[bot]" && \ git config user.email "assistant-librarian[bot]@users.noreply.github.com" && \ - git branch -m "$CK_FA_BRANCH" && git add -A && \ + git add -A && \ git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" > /dev/null ; \ else \ - git clone --depth 1 -b "$CK_FA_BRANCH" 
https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck ; \ + git clone --depth 1 -b "$CK_FA_BRANCH" https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck && \ + LOCAL_BRANCH="$CK_FA_BRANCH" ; \ fi && \ cd /home/jenkins/workspace && rm -rf flash-attention && \ git clone --depth 1 -b "$FA_BRANCH" --recursive "https://github.com/$FA_ORIGIN/flash-attention.git" && \ cd flash-attention && \ rm -rf csrc/composable_kernel/ && \ - git clone -b "$CK_FA_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \ + git clone -b "$LOCAL_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \ MAX_JOBS=$(nproc) GPU_ARCHS="$GPU_ARCHS" /opt/venv/bin/python3 -u -m pip install --no-build-isolation -v . && \ groupadd -g 1001 jenkins && \ useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \ diff --git a/Jenkinsfile b/Jenkinsfile index 170e0bf432..8675c716e7 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -840,8 +840,10 @@ def cmake_build(Map conf=[:]){ if (params.RUN_CK_TILE_FMHA_TESTS){ try{ - archiveArtifacts "perf_fmha_*.log" - stash includes: "perf_fmha_**.log", name: "perf_fmha_log_${arch_name}" + dir("projects/composablekernel"){ + archiveArtifacts "perf_fmha_*.log" + stash includes: "perf_fmha_**.log", name: "perf_fmha_log_${arch_name}" + } } catch(Exception err){ echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." @@ -918,7 +920,7 @@ def Build_CK(Map conf=[:]){ sh "projects/composablekernel/script/run_inductor_tests.sh" } // run performance tests, stash the logs, results will be processed on the master node - dir("projects/composablekernel/script"){ + dir("projects/composablekernel/script"){ if (params.RUN_PERFORMANCE_TESTS){ if (params.RUN_FULL_QA && (arch == "gfx90a" || arch == "gfx942")){ // run full tests on gfx90a or gfx942 @@ -1017,6 +1019,13 @@ def process_results(Map conf=[:]){ catch(Exception err){ echo "could not locate the FMHA performance logs for gfx90a: ${err.getMessage()}." } + try{ + unstash "perf_fmha_log_gfx950" + } + catch(Exception err){ + echo "could not locate the FMHA performance logs for gfx950: ${err.getMessage()}." + } + } if (params.BUILD_INSTANCES_ONLY){ // unstash deb packages @@ -1191,7 +1200,7 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_ 0 13 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true 0 11 * * * % RUN_FULL_CONV_TILE_TESTS=true;RUN_AITER_TESTS=true;RUN_FA_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true 0 9 * * * % RUN_PYTORCH_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX101=false;BUILD_GFX103=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : "" -CURRENT_BRANCH_NAME = env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME +CURRENT_BRANCH_NAME = env.CHANGE_ID ? "refs/pull/${env.CHANGE_ID}/head" : (env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME) POLL_SPEC = BRANCH_NAME == "develop" ? 
'H H/6 * * *' : '' diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc index 02228d7654..26c3165446 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc @@ -108,28 +108,35 @@ bool run_grouped_conv_fwd(bool do_verification, if(do_verification) { + Tensor c_host(out_g_n_k_wos_desc); + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + PassThrough>(); auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(in, wei, - out_host, + c_host, conv_param.conv_filter_strides_, conv_param.conv_filter_dilations_, conv_param.input_left_pads_, conv_param.input_right_pads_, in_element_op, wei_element_op, - out_element_op); + PassThrough{}); ref_invoker.Run(ref_argument); + out_host.ForEach([&](auto&, auto idx) + { + out_element_op(out_host(idx), c_host(idx)); + }); + out_device_buf.FromDevice(out_device.mData.data()); pass &= diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 7c3efb9c18..8c006c09db 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -22,8 +22,16 @@ from codegen.cpp_symbol_map import ( QSCALE_CHECK_MAP, QSCALE_MAP, ) +from codegen.arch import ArchTrait from codegen.utils import update_file +# Architecture trait for kernels requiring global_load_lds (CDNA3+). +# Only used for GLOBAL_LOAD_LDS variants; all other kernels are arch-agnostic. +CDNA3_PLUS_ARCH = ArchTrait( + "cdna3_plus", + preprocessor_check="defined(__gfx94__) || defined(__gfx950__)", +) + DTYPE_BITS = { "fp32": 32, "fp16": 16, @@ -34,6 +42,10 @@ DTYPE_BITS = { "bf8": 8, } +# Element size in bytes per dtype, used by the auto-generated dispatcher to +# decide kv_load_mode per-arm (total KV cache bytes vs INT32_MAX). 
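The generated dispatcher combines this per-dtype element size with the arm's compile-time kN0 tile length and the runtime paging parameters to choose between BUFFER_LOAD and GLOBAL_LOAD_LDS; the rule itself is implemented further down in this diff as fmha_batch_prefill_select_kv_load_mode in fmha_fwd.hpp. Below is a minimal standalone sketch of that rule, assuming a local KvLoadMode enum as a stand-in for ck_tile::BlockAttentionKVCacheLoadModeEnum and illustrative sizes in main(); it is not the generated code itself.

    #include <cstdint>
    #include <iostream>

    // Local stand-in for ck_tile::BlockAttentionKVCacheLoadModeEnum.
    enum class KvLoadMode { BUFFER_LOAD, GLOBAL_LOAD_LDS };

    // Mirrors the selection rule: fall back to global_load_lds only when the page is
    // smaller than one K/V tile AND the whole KV pool exceeds the 32-bit SRD offset range.
    KvLoadMode select_kv_load_mode(std::int64_t page_block_size,
                                   std::int64_t kN0,             // per-arm K/V tile length
                                   std::int64_t num_total_pages,
                                   std::int64_t batch_stride_k,  // elements per page (illustrative meaning)
                                   std::int64_t element_bytes)   // DTYPE_BYTES[dtype]
    {
        // All operands are 64-bit, so the product cannot wrap.
        const std::int64_t kv_pool_bytes = num_total_pages * batch_stride_k * element_bytes;
        return (page_block_size < kN0 && kv_pool_bytes > INT32_MAX)
                   ? KvLoadMode::GLOBAL_LOAD_LDS
                   : KvLoadMode::BUFFER_LOAD;
    }

    int main()
    {
        // fp16 (2 bytes): 2^20 pages x 2048 elements per page = 4 GiB of KV cache; page (16) < tile (128).
        const auto mode = select_kv_load_mode(/*page_block_size=*/16, /*kN0=*/128,
                                              /*num_total_pages=*/1 << 20,
                                              /*batch_stride_k=*/16 * 128,
                                              /*element_bytes=*/2);
        std::cout << (mode == KvLoadMode::GLOBAL_LOAD_LDS ? "global_load_lds\n" : "buffer_load\n");
    }

BUFFER_LOAD stays the default because the SRD path is faster; the slower GLOBAL_LOAD_LDS variants are only generated and selected for the cases the 32-bit SRD offset genuinely cannot cover, and (per the ArchTrait above) only compile for CDNA3+ targets.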
+DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()} + K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} SUPPORTED_PAGE_SIZE = [1, 16, 1024] @@ -47,6 +59,10 @@ KV_LOOKUP_TABLE_ENUM_MAP = { "vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D", "sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D", } +KV_LOAD_MODE_ENUM_MAP = { + False: "ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD", + True: "ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS", +} FMHA_BATCH_PREFILL_PIPELINE_MAP = { @@ -61,6 +77,8 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT """ FMHA_FWD_KERNEL_BODY = """ +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check}) + using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; @@ -87,7 +105,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad}, {F_sink}, {F_page_size}, {F_kv_memory_layout}, - {F_kv_lookup_table}>; + {F_kv_lookup_table}, + {F_kv_load_mode}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -125,7 +144,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>; #include @@ -140,10 +159,13 @@ float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_b constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check}) """ FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp" FMHA_FWD_API = """ +#include #include namespace {{ @@ -194,6 +216,7 @@ float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, """ FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ + constexpr int kElementBytes = {F_element_bytes}; {F_hdim_case} }} """ @@ -203,8 +226,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v """ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{ - using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, 
{F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && (fmha_batch_prefill_select_kv_load_mode(a.page_block_size, {F_bn0}, a.num_total_pages, a.batch_stride_k, kElementBytes) == {F_kv_load_mode})) {{ + using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>; return fmha_batch_prefill_(s, a); }} """ @@ -253,12 +276,14 @@ class FmhaFwdApiTrait: kv_memory_layout: str kv_lookup_table: str page_size: int = 1 # page block size + use_global_load: bool = False # use global_load_lds_* for >2GB KV cache @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}" + + ("-gload" if self.use_global_load else "-bload") ) @property @@ -481,6 +506,7 @@ class FmhaFwdApiPool: ], F_page_size=trait.page_size, F_sink=BOOL_MAP[trait.sink], + F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[trait.use_global_load], ) if_j = "if" if j == 0 else "else if" per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( @@ -488,7 +514,10 @@ class FmhaFwdApiPool: ) if_i = "if" if i == 0 else "else if" per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( - F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + F_if=if_i, + F_dtype=dtype, + F_element_bytes=DTYPE_BYTES[dtype], + F_hdim_case=per_hdim_case, ) if not per_dtypes: # empty string we add some ignore to suppress warning in api @@ -539,6 +568,7 @@ class FmhaFwdKernel: F_pipeline: FmhaFwdPipeline mask_impl: str F_page_size: int = 1 # page block size + F_use_global_load: bool = False # use global_load_lds_* for >2GB KV cache @property def template(self) -> str: @@ -588,6 +618,10 @@ class FmhaFwdKernel: F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag], F_page_size=self.F_page_size, F_sink=BOOL_MAP[self.F_pipeline.F_sink], + F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[self.F_use_global_load], + F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check + if self.F_use_global_load + else "true", ) @property @@ -595,6 +629,7 @@ class FmhaFwdKernel: # TODO: we don't encode idx here return ( f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_" + + ("gload_" if self.F_use_global_load else "bload_") + self.F_tile.name + "_" + self.F_pipeline.name @@ -632,6 +667,7 @@ class FmhaFwdKernel: kv_memory_layout=self.F_pipeline.F_kv_memory_layout, kv_lookup_table=self.F_pipeline.F_kv_lookup_table, page_size=self.F_page_size, + use_global_load=self.F_use_global_load, ) @@ -714,8 +750,11 @@ class CustomFactory(KernelComponentFactory): def get_fwd_blobs( - kernel_filter: Optional[str], receipt, optdim_list, mask_impl, - targets: Optional[List[str]] = None + kernel_filter: Optional[str], + receipt, + optdim_list, + mask_impl, + targets: Optional[List[str]] = None, ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # 
batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing # (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with @@ -837,6 +876,25 @@ def get_fwd_blobs( api_pool.register_traits(k.api_trait()) gen.append(k) + # For page_size < kN0 (tile.F_bn0), also generate a GLOBAL_LOAD_LDS + # variant for >2GB KV cache support. The default (BUFFER_LOAD) uses SRD + # buffer_load (fast, <2GB). GLOBAL_LOAD_LDS uses global_load_lds_* + # (slower, handles >2GB). + if page_size < tile.F_bn0: + k_global_load = FmhaFwdKernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + F_page_size=page_size, + F_use_global_load=True, + ) + api_pool.register_traits(k_global_load.api_trait()) + gen.append(k_global_load) + return (api_pool, gen) @@ -856,7 +914,9 @@ def write_blobs( optdim_list, mask_impl, ) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets) + api_pool, kernels = get_fwd_blobs( + kernel_filter, receipt, optdim_list, mask_impl, targets + ) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) @@ -871,7 +931,9 @@ def list_blobs( mask_impl, ) -> None: with file_path.open("a") as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets) + _, kernels = get_fwd_blobs( + kernel_filter, receipt, optdim_list, mask_impl, targets + ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 6c842def58..98e2df2e1e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -673,6 +673,33 @@ struct fmha_batch_prefill_args ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension }; +// Selects the KV-cache load mode for a batch-prefill dispatch arm. +// GLOBAL_LOAD_LDS: required when (a) the page is smaller than one K/V tile +// so per-page SRD is impossible, AND (b) the total KV-pool byte size +// exceeds INT32_MAX so SRD's 32-bit byte offset cannot address it. +// BUFFER_LOAD: every other case — the SGPR-resident SRD path is fastest. +// Inputs are taken as plain integers so the helper has no template parameter +// and can be called from each codegen-emitted dispatcher arm with the arm's +// compile-time kN0 / element_bytes substituted as constants. +inline ck_tile::BlockAttentionKVCacheLoadModeEnum +fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size, + ck_tile::index_t kN0, + ck_tile::index_t num_total_pages, + ck_tile::index_t batch_stride_k, + ck_tile::index_t element_bytes) +{ + // Promote every operand to long_index_t so overflow is impossible regardless + // of multiplication order. A bare `static_cast(num_total_pages) + // * batch_stride_k * element_bytes` only works because of left-to-right + // associativity — a future reorder of the operands would silently truncate. + const auto kv_pool_bytes = static_cast(num_total_pages) * + static_cast(batch_stride_k) * + static_cast(element_bytes); + return (page_block_size < kN0 && kv_pool_bytes > INT32_MAX) + ? 
ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS + : ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD; +} + template auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) { @@ -1457,7 +1484,9 @@ template + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D, + ck_tile::BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ = + ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD> struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_ + bool TransposeC = false, + bool ALdsScalarLoadToVgpr = false, + bool BLdsScalarLoadToVgpr = false> struct BlockwiseGemmXdlops_pipeline_base { static constexpr auto I0 = Number<0>{}; @@ -389,7 +390,7 @@ struct BlockwiseGemmXdlops_pipeline_base Sequence<1, 1, 1, KPack>, Sequence<0, 1, 2, 3>, 3, - LdsScalarLoadToVgpr ? 1 : A_K1, + ALdsScalarLoadToVgpr ? 1 : A_K1, A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3>, 3, - LdsScalarLoadToVgpr ? 1 : B_K1, + BLdsScalarLoadToVgpr ? 1 : B_K1, B_K1>; AThreadCopy a_thread_copy_; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp index 461ca513f9..f1a093a7a8 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp @@ -32,12 +32,13 @@ template + bool DirectLoad = false, + bool ALdsScalarLoadToVgpr = false, + bool BLdsScalarLoadToVgpr = false> constexpr auto BlockGemmPipeline_Selector() { // Supported for Direct Load and V1 - if constexpr(LdsScalarLoadToVgpr) + if constexpr(ALdsScalarLoadToVgpr || BLdsScalarLoadToVgpr) { static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1); } @@ -65,7 +66,8 @@ constexpr auto BlockGemmPipeline_Selector() MRepeat, NRepeat, KPack, - LdsScalarLoadToVgpr>{}; + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>{}; } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index 723ef9cd1e..6c5b2a266b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -747,7 +747,8 @@ template + bool ALdsScalarLoadToVgpr = false, + bool BLdsScalarLoadToVgpr = false> struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 { }; @@ -772,7 +773,8 @@ template + bool ALdsScalarLoadToVgpr, + bool BLdsScalarLoadToVgpr> struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr> : BlockwiseGemmXdlops_pipeline_base + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr> { using Base = BlockwiseGemmXdlops_pipeline_base; + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>; using Base::I0; using Base::KRepeat; using Base::xdlops_gemm; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000..270d4e264c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp @@ -0,0 +1,1216 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/host_utility/io.hpp" + +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const std::array gemm_kernel_args, + const index_t gemms_count, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) +{ +#if defined(__gfx9__) + // offset base pointer for each work-group + const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * num_k_per_block); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && + block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && + left <= right) + { + if(block_args_id < gemm_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + if constexpr(GridwiseGemm::DirectLoadEnabled) + { +#if defined(__gfx950__) + const auto a_grid_desc_ak0_m_ak1_transformed = + GridwiseGemm::template TransformGrid( + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_); + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1_transformed, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } + 
else + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1_transformed, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } +#endif + } + else + { + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } + else + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } + } +#else + ignore = karg; + ignore = gemm_kernel_args; + ignore = gemms_count; + ignore = compute_ptr_offset_of_batch; + ignore = num_k_per_block; + +#endif // End of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +} +} // namespace + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, K, Ho, Wo] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 + : public DeviceGroupedConvBwdDataMultipleD +{ + // TODO: Extend support for more spatial dimensions. + static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); + + static_assert(std::is_same_v, "A not NGHWC"); + static_assert(std::is_same_v, "B not GKYXC"); + static_assert(std::is_same_v, "C not NGHWK"); + + // MaxGroupedGemmGroupsNum is used to specify number of gemm args in compile time. With this + // implementation we can avoid copy data to workspace before kernel launch since number of + // groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then + // we run this kernel in the loop. + static constexpr index_t MaxGroupedGemmGroupsNum = + ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0 + ? 1 + : 32; + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static_assert(NumDTensor == 0, "Not supported"); + // static_assert(DirectLoad, "Not supported"); + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; + static constexpr bool IsSplitKSupported = false; + + // TODO: Add support for different A and B data types. 
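The kernel above resolves which grouped-GEMM entry a work-group belongs to by searching gemm_kernel_args for the [BlockStart_, BlockEnd_) interval containing its flat block id. A host-side sketch of that lookup is below; GemmRange is a stand-in for the two GemmArgs fields the search actually reads, and the device code holds the array in kernel arguments rather than a std::vector.

    #include <cstdio>
    #include <vector>

    struct GemmRange { int BlockStart_; int BlockEnd_; };   // half-open block range of one GEMM

    // Binary search over sorted, contiguous block ranges: returns the group index whose
    // [BlockStart_, BlockEnd_) contains block_id, or -1 if no group covers it.
    int find_group(const std::vector<GemmRange>& args, int block_id)
    {
        int left = 0, right = static_cast<int>(args.size());
        while (left < right) {
            const int mid = (left + right) / 2;
            if (block_id < args[mid].BlockStart_)      right = mid;
            else if (block_id >= args[mid].BlockEnd_)  left = mid + 1;
            else                                       return mid;
        }
        return -1;
    }

    int main()
    {
        // Three GEMMs occupying 4, 2 and 6 blocks of the flattened grid.
        const std::vector<GemmRange> args{{0, 4}, {4, 6}, {6, 12}};
        for (int block : {0, 5, 11})
            std::printf("block %d -> group %d\n", block, find_group(args, block));
    }

Because the constructor assigns BlockStart_/BlockEnd_ cumulatively, the ranges are sorted and contiguous, so every launched block id resolves to exactly one group.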
+ using ABDataType = ADataType; + + using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1; + + // Dummy function just used to create an alias to Grid Descriptors + static auto + GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform) + { + const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1(); + + const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1(); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + conv_to_gemm_transform.MakeCDescriptor_M_N(), 1, 1); + + return make_tuple(a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + e_grid_desc_mblock_mperblock_nblock_nperblock); + } + + static constexpr index_t ABlockTransferSrcScalarPerVectorAligned = + ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8 + ? 4 / sizeof(ADataType) + : ABlockTransferSrcScalarPerVector; + static constexpr index_t BBlockTransferSrcScalarPerVectorAligned = + BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8 + ? 4 / sizeof(BDataType) + : BBlockTransferSrcScalarPerVector; + + static constexpr bool ALdsScalarLoadToVgpr = false; + static constexpr bool BLdsScalarLoadToVgpr = true; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::RowMajor, + ADataType, + BDataType, + AccDataType, + EDataType, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + DirectLoad ? 
BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector, + BlkGemmPipeSched, + BlkGemmPipelineVer, + AComputeType, + BComputeType, + DirectLoad, + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>; + + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k; + } + + // Note: the dummy function is used just to create the alias + constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using EGridDesc_MPerBlock_NBlock_NPerBlock = remove_cvref_t>; + + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); + + struct GemmArgs + { + GemmArgs() = default; + GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + EGridDesc_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, + index_t BlockStart, + index_t BlockEnd, + bool HasMainKBlockLoop) + : a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1), + b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1), + e_grid_desc_mblock_mperblock_nblock_nperblock_( + e_grid_desc_mblock_mperblock_nblock_nperblock), + BlockStart_(BlockStart), + BlockEnd_(BlockEnd), + HasMainKBlockLoop_(HasMainKBlockLoop) + + { + } + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + EGridDesc_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + index_t BlockStart_, BlockEnd_; + bool HasMainKBlockLoop_; + }; + // block-to-e-tile map for elementwise kernels + using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt; + using Block2TileMapWeiElementwise = BlockToCTileMap_M00_N0_M01Adapt; + static constexpr index_t ClusterLengthMPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr index_t ElementwiseBlocksize = ClusterLengthMPerBlock * ClusterLengthNPerBlock; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array&, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>&, + const std::array, NumDTensor>&, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const 
std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + ck::index_t split_k = 1) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + bool image_covered_dilation = true; + bool image_covered_strides = true; + for(index_t d = 0; d < NDimSpatial; d++) + { + // If dilation and stride is not equal we will have some empty places + image_covered_dilation &= + conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1; + // If stride is larger than windows size then we will have some empty places + image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3]; + } + bool if_d_is_output_mem = false; + bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides; + + // Temporary workaround untill prove/fix above conditions. + bwd_needs_zero_out = !if_d_is_output_mem; + e_space_size_bytes = + ck::accumulate_n( + e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(EDataType); + + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + // problem definition + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; + + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; + + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + index_t grid_size = 0; + // Allocate place for sets of gemms + gemm_kernel_args_.resize( + math::integer_divide_ceil(ZTilde * YTilde * XTilde, MaxGroupedGemmGroupsNum)); + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? 
math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } + + ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes, + k_batch_}; + + conv_N_per_block_ = conv_to_gemm_transform_.N_; + + const auto a_grid_desc_ak0_m_ak1 = [&]() { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + }(); + + const auto b_grid_desc_bk0_n_bk1 = [&]() { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + }(); + + // desc for problem definition + const auto a_grid_desc_m_k = + transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1); + const auto b_grid_desc_n_k = + transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); + + const auto GemmM = a_grid_desc_m_k.GetLength(I0); + const auto GemmN = b_grid_desc_n_k.GetLength(I0); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + conv_to_gemm_transform_.MakeCDescriptor_M_N(), + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); + + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + e_grid_desc_mblock_mperblock_nblock_nperblock); + + const index_t grid_size_grp = + std::get<0>(GridwiseGemm::CalculateGridSize(GemmM, GemmN, 1, 1)); + const index_t BlockStart = grid_size; + const index_t BlockEnd = grid_size + grid_size_grp; + + grid_size += grid_size_grp; + + // const index_t GemmM = a_grid_desc_m_k.GetLength(I0); + // const index_t GemmN = b_grid_desc_n_k.GetLength(I0); + const index_t GemmK = a_grid_desc_m_k.GetLength(I1); + + // onst auto MBlock = GridwiseGemmCTranspose::CalculateMBlock(GemmM); + // onst auto NBlock = GridwiseGemmCTranspose::CalculateNBlock(GemmN); + + index_t k_grain = split_k * KPerBlock; + index_t K_split = (GemmK + k_grain - 1) / k_grain * KPerBlock; + + const bool HasMainKBlockLoop = + GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + gemm_kernel_args_[gemms_count_ / MaxGroupedGemmGroupsNum] + [gemms_count_ % MaxGroupedGemmGroupsNum] = + GemmArgs{a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + e_grid_desc_mblock_mperblock_nblock_nperblock, + BlockStart, + BlockEnd, + HasMainKBlockLoop}; + gemms_count_++; + if(gemms_count_ % MaxGroupedGemmGroupsNum == 0) + { + gemms_grid_size_.push_back(grid_size); + grid_size = 0; + } + } + } + } + gemm_kernel_args_.resize( + math::integer_divide_ceil(gemms_count_, MaxGroupedGemmGroupsNum)); + gemms_grid_size_.push_back(grid_size); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0]; + + num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_; + } + + 
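A small sketch of the launch bookkeeping this constructor sets up and the Invoker below consumes: the per-(ztilde, ytilde, xtilde) GEMMs are packed into sets of at most MaxGroupedGemmGroupsNum entries (32 here, or 1 for the Filter1x1Stride1Pad0 specialization), one kernel launch per set, with the final set holding the remainder. The counts in main() are illustrative only.

    #include <cstdio>

    constexpr int kMaxGroups = 32;   // stands in for MaxGroupedGemmGroupsNum

    int main()
    {
        const int gemms_count = 70;  // e.g. ZTilde * YTilde * XTilde after skipping empty slices
        const int num_sets    = (gemms_count + kMaxGroups - 1) / kMaxGroups;  // integer_divide_ceil
        for (int set = 0; set < num_sets; ++set) {
            // Same remainder rule the Invoker uses for gemms_count_for_set.
            const int in_this_set = (set == num_sets - 1)
                                        ? gemms_count - kMaxGroups * set
                                        : kMaxGroups;
            std::printf("launch %d: %d grouped GEMMs\n", set, in_this_set);
        }
        // prints: launch 0: 32, launch 1: 32, launch 2: 6
    }

Keeping the group count a compile-time bound is what lets the GEMM descriptors travel in the kernel-argument block instead of being staged through a workspace copy before launch.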
std::size_t GetWorkspaceSizeBytes() const { return 0; } + + void Print() const + { + for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++) + { + std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i] + << std::endl; + + std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i] + << std::endl; + + std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i] + << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + index_t conv_N_per_block_; + std::vector a_grid_desc_m_k_container_; + std::vector b_grid_desc_n_k_container_; + std::vector + e_grid_desc_mblock_mperblock_nblock_nperblock_container_; + + // tensor descriptor for block-wise copy + std::vector a_grid_desc_ak0_m_ak1_container_; + std::vector b_grid_desc_bk0_n_bk1_container_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + std::array a_g_n_k_wos_lengths_; + std::array b_g_k_c_xs_lengths_; + std::array e_g_n_c_wis_lengths_; + std::array conv_filter_strides_; + std::array input_left_pads_; + std::array input_right_pads_; + + const index_t k_batch_; + index_t num_workgroups_per_Conv_N_; + std::vector gemms_grid_size_; + index_t gemms_count_ = 0; + std::vector> gemm_kernel_args_; + + bool bwd_needs_zero_out; + long_index_t e_space_size_bytes; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + template + float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + const index_t gdy = arg.num_group_; + const index_t gdz = arg.k_batch_; + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + EDataType* p_e_grid = arg.p_e_grid_; + + for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size(); + gemm_set_id++) + { + const index_t GemmM = arg.a_grid_desc_m_k_container_[gemm_set_id].GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_container_[gemm_set_id].GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_container_[gemm_set_id].GetLength(I1); + typename GridwiseGemm::Argument gemm_arg{ + p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const index_t gdx = arg.gemms_grid_size_[gemm_set_id]; + + const index_t gemms_count_for_set = + gemm_set_id == arg.gemm_kernel_args_.size() - 1 + ? 
arg.gemms_count_ - MaxGroupedGemmGroupsNum * gemm_set_id + : MaxGroupedGemmGroupsNum; + + const std::array& gemm_kernel_args = + arg.gemm_kernel_args_[gemm_set_id]; + + const auto clear_workspace = [&]() { + if(arg.bwd_needs_zero_out && gemm_set_id == 0) + { + hip_check_error(hipMemsetAsync( + p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_)); + } + }; + + bool has_loop_in_all_gemm = true; + bool no_loop_in_all_gemm = true; + for(auto i = 0; i < gemms_count_for_set; i++) + { + has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_; + no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_; + } + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto no_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop_.value; + constexpr bool no_main_loop = no_main_k_block_loop.value; + const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MPerBlock_NBlock_NPerBlock, + MaxGroupedGemmGroupsNum, + GemmArgs, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + ElementOp, + has_main_loop, + no_main_loop>; + + return launch_and_time_kernel_with_preprocess(stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + gemm_kernel_args, + gemms_count_for_set, + arg.compute_ptr_offset_of_batch_, + 1); + }; + if(has_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(no_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + if(arg.k_batch_ > 1) + { + if constexpr(IsSplitKSupported) + { + ave_time += + RunMultiDGemm(arg, stream_config); + } + } + else + { + ave_time += RunMultiDGemm(arg, stream_config); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if constexpr(DirectLoad) + { + if(get_device_name() != "gfx950") + { + return false; + } + } + + if constexpr(!IsSplitKSupported) + { + if(arg.k_batch_ > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + + if(ck::is_gfx11_supported() && arg.k_batch_ > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "SplitK tests are not supported!" 
<< " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + + // Specialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ConvBwdDataSpecialization is unsupported!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v) + { + + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported B Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // vector store for E + if constexpr(is_same_v || + is_same_v) + { + // vector store C matrix into global memory + if(!(ConvC % CShuffleBlockTransferScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported E Layout!" 
<< " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // Check gridwise gemm validity + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); + const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + // Create gemm arguments with dummy values to check for validity + typename GridwiseGemm::Argument gemm_arg{nullptr, // p_as_grid + nullptr, // p_bs_grid + nullptr, // p_e_grid + GemmM, // M + GemmN, // N + GemmK, // K + I0, // StrideAs + I0, // StrideBs + I0, // StrideE + arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / AK1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) override + { + return std::make_unique(p_a, + p_b, + p_ds, + 
p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3" + << (DirectLoad ? "_DirectLoad" : "") + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", " + << MPerXdl << ", " + << NPerXdl << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle; + + str << ">"; + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 19a7536685..88c2207e09 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -19,6 +19,7 @@ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #ifdef CK_EXPERIMENTAL_BUILDER #include "ck_tile/builder/reflect/description.hpp" @@ -856,6 +857,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp index a811d2f44a..172a53d652 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp @@ -179,6 +179,7 @@ struct DeviceGroupedConvBwdWeight_Explicit k_batch_ = split_k; } } + k_batch_ = clamp_gemm_k_batch(k_batch_); if constexpr(IsTwoStageNeeded) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index a3eab579e7..ed0378e23f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -670,6 +670,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); const auto descs = conv_to_gemm_transformer diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 1e23fef191..ff0616481f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -695,6 +695,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); const auto descs = conv_to_gemm_transformer diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 87117be4ce..bc44cf2bb3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -611,6 +611,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); const auto descs = conv_to_gemm_transformer_v2 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 0ee5ac3647..011bb068f9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -717,6 +717,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); // Create initial descriptors with hack=false to check compactness const auto descs_initial = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index bfc88753a2..66fb526641 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -555,6 +555,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); std::array a_g_n_k_wos_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 46a9009f83..fef81b281a 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -669,6 +669,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); // Create descriptors first (with hack flags temporarily set to false) // so we can check if element space sizes are divisible by k_batch diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index dade0515af..07c8e02514 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -408,10 +408,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 ? 4 / sizeof(BDataType) : BBlockTransferSrcScalarPerVector; + static constexpr bool ALdsScalarLoadToVgpr = + (DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false); + static constexpr bool BLdsScalarLoadToVgpr = + (DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false); + + // Note: Direct load use layout to create proper block and mmtile descriptor + // TODO: Fix and verify RC layout for not direct load (currently it returns wrong results) template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3< - tensor_layout::gemm::RowMajor, - tensor_layout::gemm::ColumnMajor, + std::conditional_t, + std::conditional_t, tensor_layout::gemm::RowMajor, ADataType, BDataType, @@ -456,7 +467,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, - DirectLoad>; + DirectLoad, + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>; using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; @@ -625,6 +638,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); // Create descriptors first (with hack flags temporarily set to false) // so we can check if element space sizes match product of dimensions diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp index 9532f7e76a..87be350a44 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp @@ -162,6 +162,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) id_off += grid_size_grp; id_local += grid_size_grp; + block_sync_lds(); } } #else diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 9978b62b17..fa33e0fdea 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -136,6 +136,7 @@ __launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) id_off += grid_size_grp; id_local += grid_size_grp; + block_sync_lds(); } } #else diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp 
b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp index 3a3bacd945..ea5b282ed1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp @@ -13,6 +13,13 @@ namespace ck { namespace tensor_operation { namespace device { +/// Ensures GemmKBatch in conv to GEMM transforms is never 0 (would zero the divisor in +/// integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch)). +inline constexpr index_t clamp_gemm_k_batch(index_t k_batch) noexcept +{ + return k_batch < 1 ? index_t{1} : k_batch; +} + struct DeviceProperties { DeviceProperties() @@ -33,6 +40,10 @@ inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index const int max_capacity = max_occupancy * device_properties.num_cu_; ck::index_t k_batch = 1; + if(grid_size <= 0) + { + return k_batch; + } const auto optimal_split = static_cast(std::floor((1.0 * max_capacity) / grid_size)); if(optimal_split > 1) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index a0fca218d4..c134d34161 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -66,7 +66,9 @@ template + bool DirectLoad = false, + bool ALdsScalarLoadToVgpr = false, + bool BLdsScalarLoadToVgpr = false> struct GridwiseGemm_xdl_cshuffle_conv_v3 : public GridwiseGemm_xdl_cshuffle_base< ALayout, @@ -249,19 +251,90 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 return math::integer_divide_ceil(N, NPerBlock); } - template + template + __host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc) + { + + if constexpr(!DirectLoad) + { + return desc; + } + else + { + const index_t K = desc.GetLength(I0) * desc.GetLength(I2); + const index_t MN = desc.GetLength(I1); + + const auto desc_unmerged = transform_tensor_descriptor( + desc, + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, K0Number)), + make_pass_through_transform(MN), + make_pass_through_transform(K1Value)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto desc_permuted = transform_tensor_descriptor( + desc_unmerged, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(MN, K0Number)), + make_pass_through_transform(K1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, K0Number)), + make_pass_through_transform(MN), + make_pass_through_transform(K1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + template __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) { - constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); - constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + if constexpr(DirectLoad && IsKContinous) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); - return transform_tensor_descriptor( - TileDesc_K0_MN_K1{}, - 
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}))), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}), - make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{}); + + constexpr auto desc = transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_xor_with_modulo_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + desc, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + else + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } } template @@ -270,7 +343,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 { constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + return MakeGemmMmaTileDescriptor::value>( + ABlockDesc_AK0_M_AK1{}); } template @@ -279,7 +356,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 { constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + return MakeGemmMmaTileDescriptor::value>( + BBlockDesc_BK0_N_BK1{}); } struct Problem @@ -366,9 +447,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 { if constexpr(DirectLoad) { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{}, I1, Number{})); + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{}, I1, Number{})); + } } else if constexpr(is_same_v) { @@ -389,9 +479,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 { if constexpr(DirectLoad) { - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(Number{}, I1, Number{})); + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{}, I1, Number{})); + } } else if constexpr(is_same_v) { @@ -410,34 +509,35 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // Disable vector load from lds to vgpr for direct load (backward weight store with continous M // or N dimension) - static constexpr bool LdsScalarLoadToVgpr = DirectLoad; - using BlockwiseGemmPipe = remove_cvref_t< - decltype(BlockGemmPipeline_Selector< - BlkGemmPipelineVer, - BlkGemmPipeSched, - BlockSize, - ADataType, - BDataType, - ComputeTypeA, - AccDataType, - 
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), - decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), - decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + // static constexpr bool LdsScalarLoadToVgpr = DirectLoad; + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), - decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MPerXdl, - NPerXdl, - MXdlPerWave, - NXdlPerWave, - KPack, - DirectLoad, - LdsScalarLoadToVgpr>())>; + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + DirectLoad, + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>())>; template __device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch) @@ -517,8 +617,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0, - const index_t k_batch = 1) + const index_t k_id = 0, + const index_t k_batch = 1, + const index_t block_idx_x = static_cast(blockIdx.x)) { const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1; const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1; @@ -535,8 +636,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( - make_multi_index(static_cast(blockIdx.x))); + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx_x)); if(!block_2_ctile_map.ValidCTileIndex( block_work_idx, @@ -570,23 +671,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 auto get_a_blockwise_copy = [&]() { if constexpr(DirectLoad) { - return ThreadGroupTensorSliceTransfer_DirectLoad< - ThisThreadBlock, - Sequence, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - 1, - ABlockTransferSrcScalarPerVector>( - a_grid_desc_ak0_m_ak1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0)); + return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType, + decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, + is_same::value ? 2 : 1, + ABlockTransferSrcScalarPerVector > + (a_grid_desc_ak0_m_ak1, + make_multi_index( + SplitKOffsetHack ? 
0 : k_id, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); } else { @@ -626,23 +723,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 auto get_b_blockwise_copy = [&]() { if constexpr(DirectLoad) { - return ThreadGroupTensorSliceTransfer_DirectLoad< - ThisThreadBlock, - Sequence, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - 1, - BBlockTransferSrcScalarPerVector>( - b_grid_desc_bk0_n_bk1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0)); + return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType, + decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, + is_same::value ? 2 : 1, + BBlockTransferSrcScalarPerVector > + (b_grid_desc_bk0_n_bk1, + make_multi_index( + SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); } else { @@ -750,8 +843,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0, - const index_t k_batch = 1) + const index_t k_id = 0, + const index_t k_batch = 1, + const index_t block_idx_x = static_cast(blockIdx.x)) { const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1; const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1; @@ -771,7 +865,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( - make_multi_index(static_cast(blockIdx.x))); + make_multi_index(static_cast(block_idx_x))); if(!block_2_ctile_map.ValidCTileIndex( block_work_idx, @@ -805,23 +899,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 auto get_a_blockwise_copy = [&]() { if constexpr(DirectLoad) { - return ThreadGroupTensorSliceTransfer_DirectLoad< - ThisThreadBlock, - Sequence, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - 1, - ABlockTransferSrcScalarPerVector>( - a_grid_desc_ak0_m_ak1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0)); + return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType, + decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, + is_same::value ? 2 : 1, + ABlockTransferSrcScalarPerVector > + (a_grid_desc_ak0_m_ak1, + make_multi_index( + SplitKOffsetHack ? 
0 : k_id, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); } else { @@ -861,23 +951,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 auto get_b_blockwise_copy = [&]() { if constexpr(DirectLoad) { - return ThreadGroupTensorSliceTransfer_DirectLoad< - ThisThreadBlock, - Sequence, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - 1, - BBlockTransferSrcScalarPerVector>( - b_grid_desc_bk0_n_bk1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0)); + return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType, + decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, + is_same::value ? 2 : 1, + BBlockTransferSrcScalarPerVector > + (b_grid_desc_bk0_n_bk1, + make_multi_index( + SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); } else { diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp index 3379fb2c59..74ec0af7d5 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -21,6 +21,10 @@ template struct TransformConvBwdWeightToGemm { + // Same contract as TransformConvBwdWeightToGemmV2 (non-zero K tile factors). + static_assert(GemmK1Number > 0, "GemmK1Number must be positive"); + static_assert(K0PerBlock > 0, "K0PerBlock must be positive"); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index 94eae555e9..eeef3e736e 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -31,6 +31,11 @@ template struct TransformConvBwdWeightToGemmV2 { + // Compile-time contract: divisor GemmK1Number * K0PerBlock * GemmKBatch in + // integer_divide_ceil(GemmKTotal, ...) must stay non-zero (GemmKBatch clamped at runtime). + static_assert(GemmK1Number > 0, "GemmK1Number must be positive"); + static_assert(K0PerBlock > 0, "K0PerBlock must be positive"); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 8056b76af7..0cb4dbeff4 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1319,6 +1319,87 @@ CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); } +// Flat async load from global memory to LDS using 64-bit global addressing. 
+// Bypasses the SRD's 32-bit offset limit; required when the KV cache exceeds +// INT32_MAX (2GB) byte offset on the SRD voffset path. +// +// !!! M0 PRECONDITION — IMPLICIT INPUT NOT VISIBLE IN OPERAND LIST !!! +// +// The LDS destination address is taken from M0 (per AMD CDNA3 ISA §10.3: +// `LDS_ADDR = LDSbase + LDSoffset(M0[17:2] * 4) + INST.OFFSET + ThreadID*4`). +// M0 does NOT appear as an operand of these instructions or of the inline +// asm below — the compiler cannot see the dependency. Caller must: +// +// 1. Initialize M0 once before the load loop: +// `m0_set_with_memory(amd_wave_read_first_lane(lds_byte_offset));` +// M0 is SALU-only — `m0_set_with_memory` uses an "s" constraint to +// enforce this. Direct VALU writes to M0 are illegal. +// +// 2. Advance M0 between successive issues: +// `m0_inc_with_memory(size_per_issue);` +// `size_per_issue` MUST be a multiple of 4 — GLOBAL/FLAT LDS path +// only honors M0[17:2]*4 (dword-aligned), so low 2 bits are silently +// dropped (NOTE: this differs from MUBUF buffer_load_lds which uses +// M0[15:0] as a raw byte offset). +// +// 3. Never bundle `m0_inc_with_memory` and the next call to this +// function into a single inline asm. The compiler auto-inserts a +// hazard NOP between an SALU write to M0 and the consuming +// `global_load_lds_*`; bundling bypasses that and may read stale M0. +// +// The "memory" clobber on this asm is load-bearing: it prevents the +// compiler from reordering this load across other M0-touching helpers +// (`m0_set_with_memory` / `m0_inc_with_memory`, also "memory"-clobbered). +// +// Verified instruction emission (HIP 6.4 / clang 19, gfx942 + gfx950): +// `global_load_lds_dwordx4` is a single instruction (encoding 0xDDF48000 +// 0x007F0000), NOT software-expanded into 4× dword. Same encoding on both +// arches. The opcode is undocumented in CDNA3 ISA spec §13.6.2 but +// supported by the LLVM AMDGPU backend. +// +// Available on gfx940+ (CDNA3: MI300, MI355, MI350 series). +template +CK_TILE_DEVICE void +async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant = {}) +{ +#if !defined(__gfx94__) && !defined(__gfx950__) + static_assert(always_false_v>, + "global_load_lds requires CDNA3+ (gfx940/gfx950). " + "Ensure kKVLoadMode is BUFFER_LOAD on this architecture."); +#endif + + static_assert(num_dwords == 1 || num_dwords == 4, + "global_load_lds supports num_dwords == 1 or 4 only " + "(2 dwords does not exist on any supported arch; " + "3 dwords only on CDNA4 and unused in FMHA pipeline)"); + +// Inline asm: only the global address is an explicit operand. The LDS +// destination is implicit via M0 (see contract above). `"=r"(smem)` is a +// SSA scheduling anchor only — `smem` is NOT written by this asm; the +// load goes to LDS at `M0[17:2]*4 + offset:0 + ThreadID*4`. 
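+// Typical call sequence implied by the M0 contract above (illustrative sketch only;
+// the helper names are the ones referenced in this comment, while `num_issues`,
+// `elems_per_issue` and `size_per_issue` are hypothetical placeholders):
+//
+//   m0_set_with_memory(amd_wave_read_first_lane(lds_byte_offset)); // once, before the loop
+//   for(index_t i = 0; i < num_issues; ++i)
+//   {
+//       async_global_load_lds_dwordxn<4>(smem, gptr + i * elems_per_issue);
+//       m0_inc_with_memory(size_per_issue); // multiple of 4; kept as a separate statement
+//   }
+//   async_buffer_load_fence(); // s_waitcnt vmcnt: wait for the LDS writes before reading LDS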
+#define CK_TILE_GLOBAL_LOAD_LDS_INSTR(instr) \ + if constexpr(pre_nop) \ + asm volatile("s_nop 4\n" instr " %1, off offset:0" \ + : "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \ + : "v"(global_addr) \ + : "memory" /*prevents reorder across m0_{set,inc}*/); \ + else \ + asm volatile(instr " %1, off offset:0" \ + : "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \ + : "v"(global_addr) \ + : "memory" /*prevents reorder across m0_{set,inc}*/); + + if constexpr(num_dwords == 1) + { + CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dword"); + } + else if constexpr(num_dwords == 4) + { + CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dwordx4"); + } +#undef CK_TILE_GLOBAL_LOAD_LDS_INSTR +} + template CK_TILE_DEVICE thread_buffer diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index aa29345892..45131abb97 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -45,9 +45,29 @@ template > + typename YsGatherDims = sequence<0>, + bool kUseGlobalLoad_ = false> struct tile_scatter_gather { + static constexpr bool kUseGlobalLoad = kUseGlobalLoad_; + +#if !defined(__gfx94__) && !defined(__gfx950__) + // global_load_lds instruction is only available on CDNA3+ (gfx940/gfx950). + // On other architectures, kUseGlobalLoad must be false. + static_assert(!kUseGlobalLoad_, + "kUseGlobalLoad requires global_load_lds (CDNA3+: gfx940/gfx950). " + "This kernel should not be instantiated on this architecture."); +#endif + + // Empty placeholder used by the SRD instantiation so physical_pages_ and + // page_stride_elements_ occupy zero bytes there (combined with + // [[no_unique_address]] on the member declarations). Access sites are all + // inside `if constexpr(kUseGlobalLoad_)` arms, which compile out in SRD + // mode, so no caller needs to change. + struct gl_field_empty_t + { + }; + using BottomTensorView = remove_reference_t; using WindowLengths = remove_cvref_t; using TileDstr = remove_cvref_t; @@ -233,15 +253,22 @@ struct tile_scatter_gather const BottomTensorIndex& window_origin, const TileDstr& tile_distribution, const PageIdxArray& page_idx, - const ValidArray& valids) + const ValidArray& valids, + index_t page_stride_elements = 0) : bottom_tensor_view_{bottom_tensor_view}, window_lengths_{window_lengths}, window_origin_{window_origin}, tile_dstr_{tile_distribution}, page_idx_{page_idx}, + physical_pages_{}, + page_stride_elements_{}, valids_{valids}, pre_computed_coords_{} { + if constexpr(kUseGlobalLoad_) + { + page_stride_elements_ = page_stride_elements; + } #if 0 // debug // TODO: this use more register for FA, but less register for GEMM // need investigation @@ -357,6 +384,34 @@ struct tile_scatter_gather bottom_tensor_view_.buf_.p_data_ = data; } + // Override buffer size (input in RAW elements, NOT pre-divided by PackedSize) for + // SRD num_records control. Use to set max range when SRD is rebased per-tile + // (page_size >= kN0 path): each rebased SRD only needs to cover one page; without + // this the SRD claims validity for memory beyond the allocated buffer, which can + // fault on gfx950 page-table validation. + // + // Matches buffer_view ctor convention (buffer_view.hpp:245): input is raw element + // count and is divided by PackedSize before being stored. 
For PackedSize=1 + // (fp16/bf16/fp8) the division is a no-op; for PackedSize=2 (FP4 / packed int4) + // skipping it would over-report num_records by 2x and silently mask OOB on SRD + // reads. batch_prefill currently does not exercise the packed-type path, but this + // setter is generic infrastructure (lives in tile_scatter_gather.hpp) so it must + // honor the same invariant the ctor enforces. + CK_TILE_DEVICE constexpr void set_bottom_tensor_view_buffer_size(index_t size) + { + // Hint the optimizer that size is positive without inserting a runtime + // branch. Using assert() here corrupted gfx950 batch_prefill + // output: the __assert_fail handler's SGPR pressure forced the K-SRD + // register window to be reused as scratch and scattered the SRD writes + // across two conditional branches, which gfx950's packed + // buffer_load_dwordx4 issue window doesn't tolerate (gfx942 absorbs it + // via per-tile single-dword loads). __builtin_assume is hint-only — + // no branch, no scratch SGPRs, no codegen impact. + __builtin_assume(size > 0); + using BufType = remove_cvref_t; + bottom_tensor_view_.buf_.buffer_size_ = size / BufType::PackedSize; + } + // move thread's window adaptor coordinate and bottom tensor coordinate // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] template @@ -458,7 +513,21 @@ struct tile_scatter_gather // read from bottom tensor const vector_t vec_value = [&]() { - if constexpr(std::is_same_v) + if constexpr(kUseGlobalLoad_) + { + // Global load mode: 64-bit typed pointer arithmetic + const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_; + const auto physical_page = physical_pages_[idx_gather]; + const auto coord_offset = bottom_tensor_thread_coord.get_offset(); + const long_index_t total_offset = + static_cast(physical_page) * page_stride_elements_ + + coord_offset + page_offset; + const auto* addr = base_ptr + total_offset; + vector_t v; + __builtin_memcpy(&v, addr, sizeof(vector_t)); + return v; + } + else if constexpr(std::is_same_v) { return get_bottom_tensor_view().template get_vectorized_elements( bottom_tensor_thread_coord, @@ -680,7 +749,23 @@ struct tile_scatter_gather const auto page_offset = page_idx_[idx_gather]; // read from bottom tensor - if constexpr(std::is_same_v) + if constexpr(kUseGlobalLoad_) + { + // Global load mode: global_load_lds with 64-bit address + constexpr index_t vector_size = + sizeof(vector_t) / sizeof(uint32_t); // dwords per vector + const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_; + const auto physical_page = physical_pages_[idx_gather]; + const auto coord_offset = bottom_tensor_thread_coord.get_offset(); + const long_index_t total_offset = + static_cast(physical_page) * page_stride_elements_ + + coord_offset + page_offset; + const auto* addr = base_ptr + total_offset; + // global_load_lds takes a byte address; addr (const DataType*) + // converts implicitly to const void*, no explicit cast needed. 
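+                // Note: vector_size can only be 1 or 4 dwords here; the static_assert in
+                // async_global_load_lds_dwordxn rejects any other width at compile time, so a
+                // K/V vector type that is 2 or 3 dwords wide cannot take this path.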
+ async_global_load_lds_dwordxn(smem, addr, pre_nop_); + } + else if constexpr(std::is_same_v) { get_bottom_tensor_view().template async_get_vectorized_elements_raw( smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_); @@ -1046,6 +1131,13 @@ struct tile_scatter_gather CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; } + CK_TILE_DEVICE void update_physical_pages(const PageIdxArray& pages) + { + static_assert(kUseGlobalLoad_, + "global-load mode only; physical_pages_ is unused in SRD mode."); + physical_pages_ = pages; + } + CK_TILE_DEVICE void update_valids(const ValidArray& new_valids) { if constexpr(std::is_same_v == false) @@ -1139,7 +1231,29 @@ struct tile_scatter_gather // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d] TileDstr tile_dstr_; + // Scatter/gather offsets for each element, set by update_page_idx(). + // SRD mode (kUseGlobalLoad=false): buffer_load(SRD, page_idx_[i] + coord). + // page_idx_[i] = within-page offset when kPageBlockSize >= kN0 (SRD rebased to page base) + // page_idx_[i] = page_base + within-page offset when kPageBlockSize < kN0 (full voffset) + // Global load mode (kUseGlobalLoad=true): page_idx_[i] = within-page offset only. + // Full address = base + physical_pages_[i] * page_stride_elements_ + page_idx_[i] + coord PageIdxArray page_idx_; + + // Physical page indices for global load mode (kUseGlobalLoad=true only). + // Maps each gather element to its physical page in a paged memory pool. + // Updated via update_physical_pages() before each load call. + // SRD mode: collapsed to gl_field_empty_t so the storage disappears. + [[no_unique_address]] std::conditional_t + physical_pages_; + + // Page stride in elements for global load mode (kUseGlobalLoad=true only). + // physical_pages_[i] * page_stride_elements_ gives the page base offset in elements. + // Set at construction time via the make_tile_scatter_gather overload that + // takes bool_constant; immutable thereafter. + // SRD mode: collapsed to gl_field_empty_t so the storage disappears. + [[no_unique_address]] std::conditional_t + page_stride_elements_; + ValidArray valids_; // this contains: @@ -1178,7 +1292,8 @@ template + index_t... 
YsGatherDims, + bool UseGlobalLoad = false> CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(const TensorView_& tensor_view, const WindowLengths_& window_lengths, @@ -1187,7 +1302,9 @@ make_tile_scatter_gather(const TensorView_& tensor_view, const StaticPageIndexArray_& page_idx, number, number, - sequence) + sequence, + bool_constant = {}, + index_t page_stride_elements = 0) { return tile_scatter_gather, remove_cvref_t, @@ -1196,11 +1313,17 @@ make_tile_scatter_gather(const TensorView_& tensor_view, std::nullptr_t, HsGatherDim, NumCoord, - sequence>{ - tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; + sequence, + UseGlobalLoad>{tensor_view, + window_lengths, + origin, + tile_distribution, + page_idx, + nullptr, + page_stride_elements}; } -// Legacy overload (compatible with original API) +// Legacy overload (compatible with original API, kUseGlobalLoad=false) template +CK_TILE_DEVICE constexpr auto +make_tile_scatter_gather(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + const StaticPageIndexArray_& page_idx, + bool_constant, + index_t page_stride_elements = 0) +{ + return tile_scatter_gather, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + std::nullptr_t, + 0, + 1, + sequence<0>, + UseGlobalLoad>{tensor_view, + window_lengths, + origin, + tile_distribution, + page_idx, + nullptr, + page_stride_elements}; +} + template ` — a value-template that is always `false` but whose +// evaluation is deferred until template instantiation. The canonical use is +// inside the `else` arm of an `if constexpr` chain or under an arch-gated +// `#if` to fire a `static_assert` ONLY when the offending instantiation is +// actually requested, e.g.: +// +// if constexpr (...) { ... } +// else { static_assert(always_false_v, "unsupported T"); } +// +// A bare `static_assert(false, ...)` would fire at template-definition +// parse time on conforming compilers, breaking the whole TU. 
+template +inline constexpr bool always_false_v = false; + // remove_cvref_t template using remove_reference_t = typename std::remove_reference::type; diff --git a/include/ck_tile/host/rotating_buffers.hpp b/include/ck_tile/host/rotating_buffers.hpp index baec4b45e8..32745ee424 100644 --- a/include/ck_tile/host/rotating_buffers.hpp +++ b/include/ck_tile/host/rotating_buffers.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/host/hip_check_error.hpp" #include +#include namespace ck_tile { diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 8a5d77bf46..cf651312d9 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -3,6 +3,7 @@ #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" @@ -55,6 +56,7 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp new file mode 100644 index 0000000000..826cd106f1 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +// KV cache load addressing mode selector for batch_prefill / paged-attention pipelines. 
+// - BUFFER_LOAD: SGPR-based SRD via buffer_load_* (default; 32-bit byte addressing, <2GB pool) +// - GLOBAL_LOAD_LDS: direct global_load_lds_* (64-bit addressing, required for >2GB KV cache) +enum class BlockAttentionKVCacheLoadModeEnum +{ + BUFFER_LOAD = 0, + GLOBAL_LOAD_LDS = 1, +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index b04205f2c2..b7dcdb3648 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -32,6 +32,83 @@ namespace ck_tile { +namespace detail { + +// A helper struct for detecting n0loop +template +struct has_n0loop_flag : std::false_type +{ +}; + +template +struct has_n0loop_flag< + T, + std::enable_if_t && T::kUseN0Loop>> + : std::true_type +{ +}; + +template +static inline constexpr bool is_n0loop_pipeline_v = has_n0loop_flag::value; + +// A helper struct for detecting ignore_fast_exp2 flag +template +struct has_ignore_fast_exp2_flag : std::false_type +{ +}; + +// IgnoreFastExp2 is used by some pipeline which explicitly chooses not to use FAST_EXP2; +// By detecting the kIgnoreFastExp2 from the pipeline, the kernel's MakeKargsImpl() interface +// is able to avoid passing an in-correct scale_s parameter to the kernel layer +template +struct has_ignore_fast_exp2_flag< + T, + std::enable_if_t && + T::kIgnoreFastExp2>> : std::true_type +{ +}; + +template +static inline constexpr bool ignore_fast_exp2_v = has_ignore_fast_exp2_flag::value; + +// A helper struct for detecting naive_hdim_load, naive_hdim_load means load tiles of +// hdim96/hdim160/hdim192 without padding the tensor_view/tile_window to hdim128/hdim256 +// naive_hdim_load is current supported by the qr_ks_vs_whole_k_prefetch_pipeline +template +struct has_naive_hdim_load_flag : std::false_type +{ +}; + +template +struct has_naive_hdim_load_flag< + T, + std::enable_if_t && + T::kIsNaiveHDimLoad>> : std::true_type +{ +}; + +template +static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag::value; + +// A helper struct for detecting kUseTrLoad +template +struct has_use_trload_flag : std::false_type +{ +}; + +template +struct has_use_trload_flag< + T, + std::enable_if_t && T::kUseTrLoad>> + : std::true_type +{ +}; + +template +static inline constexpr bool is_using_trload_v = has_use_trload_flag::value; + +} // namespace detail + template struct FmhaFwdKernel { @@ -77,13 +154,14 @@ struct FmhaFwdKernel static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; + static constexpr bool kUseTrLoad = detail::is_using_trload_v; - static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad; #if defined(__gfx950__) static constexpr bool kIsAvailable = true; #else static constexpr bool kIsAvailable = !kUseTrLoad; #endif + static constexpr std::string_view kPipelineName = FmhaPipeline::name; template // to avoid duplicated base class prblem, introduce an template @@ -444,7 +522,9 @@ struct FmhaFwdKernel num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 - static_cast(scale_s * ck_tile::log2e_v<>), + detail::ignore_fast_exp2_v + ? scale_s + : static_cast(scale_s * ck_tile::log2e_v<>), #else scale_s, #endif @@ -897,7 +977,9 @@ struct FmhaFwdKernel num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 - static_cast(scale_s * ck_tile::log2e_v<>), + detail::ignore_fast_exp2_v + ? 
scale_s + : static_cast(scale_s * ck_tile::log2e_v<>), #else scale_s, #endif @@ -1039,6 +1121,7 @@ struct FmhaFwdKernel const void* seqlen_k_ptr, const void* block_scale_seqstart_q_ptr, const void* block_scale_seqstart_k_ptr, + const void* seqstart_v_scale_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -1097,6 +1180,7 @@ struct FmhaFwdKernel seqlen_k_ptr, block_scale_seqstart_q_ptr, block_scale_seqstart_k_ptr, + seqstart_v_scale_ptr, hdim_q, hdim_v, num_head_q, @@ -1158,6 +1242,7 @@ struct FmhaFwdKernel const void* seqlen_k_ptr, const void* block_scale_seqstart_q_ptr, const void* block_scale_seqstart_k_ptr, + const void* seqstart_v_scale_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -1216,6 +1301,7 @@ struct FmhaFwdKernel seqlen_k_ptr, block_scale_seqstart_q_ptr, block_scale_seqstart_k_ptr, + seqstart_v_scale_ptr, hdim_q, hdim_v, num_head_q, @@ -1602,6 +1688,10 @@ struct FmhaFwdKernel static_cast(i_nhead) * kargs.nhead_stride_o + batch_offset_o; + constexpr index_t kQKHeaddimToUse = detail::is_naive_hdim_load_v + ? FmhaPipeline::kQKHeaddim + : FmhaPipeline::kSubQKHeaddim; + // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { const auto q_dram_naive = make_naive_tensor_view( @@ -1612,10 +1702,10 @@ struct FmhaFwdKernel number<1>{}); if constexpr(FmhaPipeline::kQLoadOnce) { - return pad_tensor_view(q_dram_naive, - make_tuple(number{}, - number{}), - sequence{}); + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); } else { @@ -1634,10 +1724,21 @@ struct FmhaFwdKernel number<1>{}); constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; - return pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + + if constexpr(detail::is_n0loop_pipeline_v) + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } }(); const auto v_dram = [&]() { if constexpr(std::is_same_v) @@ -1649,18 +1750,29 @@ struct FmhaFwdKernel number{}, number<1>{}); - const auto v_dram_transposed = transform_tensor_view( - v_dram_naive, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k)), - make_tuple(sequence<1>{}, sequence<0>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + if constexpr(!kUseTrLoad) + { + const auto v_dram_transposed = transform_tensor_view( + v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; - return pad_tensor_view( - v_dram_transposed, - make_tuple(number{}, number{}), - sequence{}); + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? 
kPadSeqLenK : false; + + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }; } else { @@ -1683,17 +1795,28 @@ struct FmhaFwdKernel q_dram, [&]() { if constexpr(FmhaPipeline::kQLoadOnce) - return make_tuple(number{}, - number{}); + return make_tuple(number{}, number{}); else return make_tuple(number{}, number{}); }(), {i_m0, 0}); - auto k_dram_window = make_tile_window( - k_dram, - make_tuple(number{}, number{}), - {0, 0}); + auto k_dram_window = [&]() { + if constexpr(detail::is_n0loop_pipeline_v) + { + return make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {0, 0}); + } + else + { + return make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {0, 0}); + } + }(); auto v_dram_window = make_tile_window( v_dram, @@ -1843,7 +1966,10 @@ struct FmhaFwdKernel *(reinterpret_cast(kargs.alibi_slope_ptr) + i_batch_ * kargs.alibi_slope_stride + i_nhead_); #if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; + if constexpr(!detail::ignore_fast_exp2_v) + { + slope *= ck_tile::log2e_v<>; + } #endif if constexpr(kHasMask) { @@ -2826,7 +2952,10 @@ struct FmhaFwdKernel *(reinterpret_cast(kargs.alibi_slope_ptr) + i_batch_ * kargs.alibi_slope_stride + i_nhead_); #if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; + if constexpr(!detail::ignore_fast_exp2_v) + { + slope *= ck_tile::log2e_v<>; + } #endif if constexpr(kHasMask) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 4f2d3d58c2..8aa6d17dc3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" @@ -134,7 +135,8 @@ template + index_t kVectorSize, + bool kUseGlobalLoad_ = false> CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages, const index_t& stride_token, const index_t& stride_page_block, @@ -156,81 +158,65 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica const index_t& thread_coord_start = coord_vec[kCoordAxis]; constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1; - if constexpr(kIsKcache) - { - // K cache: per-token lookup - // Each token may be on a different page, so we use physical_pages[k0] for each. - static_for<0, kLoopCount, 1>{}([&](auto k0) { - const index_t global_token_idx = - global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + // Addressing strategy — four cases controlled by (kPageBlockSize vs kN0, kUseGlobalLoad_): + // + // Case 1: kPageBlockSize >= kN0 + // SRD is rebased per-tile to the page base (rebase_{k,v}_window in caller). + // Page base is absorbed into the SRD's 48-bit base pointer (SGPR-resident). + // This function writes within-page offset only. 
+ // + // Case 2: kPageBlockSize < kN0 && kUseGlobalLoad_ + // SRD cannot be rebased (multi-page wave). Loads use global_load_lds_*; the full + // 64-bit address is computed by tile_scatter_gather::load() in + // include/ck_tile/core/tensor/tile_scatter_gather.hpp from physical_pages_ + + // page_stride_elements_. This function writes within-page offset only. + // + // Case 3: kPageBlockSize < kN0 && !kUseGlobalLoad_ (kNeedFullOffset == true) + // SRD base is the entire KV buffer; the only place to encode page identity + // is the voffset itself. This function writes the FULL offset: + // page * stride_page_block + within_page + // Limited to <2GB total KV bytes by 32-bit voffset hardware width. + // + // Case 4: kPageBlockSize >= kN0 && kUseGlobalLoad_ + // Not emitted by codegen. Backstop static_assert in + // BlockFmhaBatchPrefillPipelineQRKSVSAsync. + constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseGlobalLoad_; - if constexpr(kPageBlockSize >= kN0) + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; + const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + + // Within-page offset (layout-dependent for V cache with VECTORIZED_LAYOUT) + const index_t within_page = [&]() { + if constexpr(!kIsKcache && kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) { - // SRD rebasing mode: within-page offset only. - // The full page base is handled by rebasing the SRD pointer. - kv_offset_vec[k0] = token_idx_in_page * stride_token; + return (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) + + (token_idx_in_page % kVectorSize); } else { - // Full global offset (original code path for ps1, ps16, etc.) - const index_t physical_page = physical_pages[k0]; - kv_offset_vec[k0] = - physical_page * stride_page_block + token_idx_in_page * stride_token; + return token_idx_in_page * stride_token; } - }); - } - else // V cache - { - // V cache: use physical_pages[k0] for each token - // physical_pages was already populated correctly by load_physical_pages(), handling: - // - page_size=1: page_idx maps token_idx -> physical_page directly - // - V tile crosses pages: per-token page lookup - // - V tile in single page: lane0 lookup with broadcast to all lanes - static_for<0, kLoopCount, 1>{}([&](auto k0) { - const index_t global_token_idx = - global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + }(); - if constexpr(kPageBlockSize >= kN0) - { - // SRD rebasing mode: within-page offset only. - // The full page base is handled by rebasing the SRD pointer. - if constexpr(kKVMemoryLayout == - BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) - { - const index_t token_offset = - (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) + - (token_idx_in_page % kVectorSize); - kv_offset_vec[k0] = token_offset; - } - else - { - kv_offset_vec[k0] = token_idx_in_page * stride_token; - } - } - else - { - // Full global offset (original code path for ps1, ps16, etc.) 
- const index_t physical_page = physical_pages[k0]; - const long_index_t page_base_offset = - static_cast(physical_page) * stride_page_block; - - if constexpr(kKVMemoryLayout == - BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) - { - const index_t token_offset = - (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) + - (token_idx_in_page % kVectorSize); - kv_offset_vec[k0] = page_base_offset + token_offset; - } - else - { - kv_offset_vec[k0] = page_base_offset + token_idx_in_page * stride_token; - } - } - }); - } + // SRD + page_size < kN0: add page base to form complete voffset for buffer_load. + // + // 32-bit by hardware: SRD buffer_load voffset is fundamentally 32-bit (CDNA3 MUBUF + // microcode format), so this branch is only reachable when total KV bytes fit in + // INT32_MAX. The kUseGlobalLoad_ template path handles the >2GB case via 64-bit + // global_load_lds_*; widening kv_offset_vec here would not lift the 2GB ceiling + // because the hardware truncates voffset regardless. + if constexpr(kNeedFullOffset) + { + kv_offset_vec[k0] = physical_pages[k0] * stride_page_block + within_page; + } + else + { + kv_offset_vec[k0] = within_page; + } + }); } // a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) @@ -270,10 +256,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; static constexpr index_t kPageBlockSize = Problem::kPageBlockSize; static constexpr index_t kVectorSize = Problem::kVectorSize; - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - static constexpr auto I3 = number<3>{}; + // Single load-mode selector for the whole pipeline. GLOBAL_LOAD_LDS routes K/V + // tiles through global_load_lds_* (handles >2GB KV cache); BUFFER_LOAD uses SRD + // buffer_load_*. The enum is named at the trait/Problem level; internally we + // derive a bool helper to keep `if constexpr` sites narrow. Codegen only emits + // GLOBAL_LOAD_LDS arms when page_size < kN0; the static_assert is a backstop. 
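+    // Illustrative trait-level selection (hypothetical helper; the real choice is made by
+    // codegen, which only emits GLOBAL_LOAD_LDS instantiations when page_size < kN0):
+    //   constexpr auto pick_kv_load_mode(long_index_t kv_pool_bytes, index_t page_block, index_t n0)
+    //   {
+    //       return (page_block < n0 && kv_pool_bytes > long_index_t{INT32_MAX})
+    //                  ? BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS
+    //                  : BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD;
+    //   }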
+ static constexpr auto kKVLoadMode = Problem::kKVLoadMode; + static constexpr bool kUseGlobalLoad = + (kKVLoadMode == BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS); + static_assert(!kUseGlobalLoad || (kPageBlockSize < kN0), + "GLOBAL_LOAD_LDS load mode is only valid when kPageBlockSize < kN0; " + "codegen should not emit this instantiation otherwise."); + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); static constexpr bool kIsGroupMode = Problem::kIsGroupMode; @@ -626,19 +623,26 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, true, kN0, - kVectorSize>( + kVectorSize, + kUseGlobalLoad>( k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window.get_window_lengths(), k_dram_block_window.get_window_origin(), k_dist, - k_offsets); // K DRAM tile window for + k_offsets, + bool_constant{}, + page_stride_k); + if constexpr(kUseGlobalLoad) + { + k_dram_window.update_physical_pages(k_physical_pages); + } k_dram_window.init_raw(); - // SRD rebasing: move the buffer descriptor base pointer to each page's start address - // using 48-bit pointer arithmetic, so voffset only needs the small within-page offset. - // Only applies when kPageBlockSize >= kN0 (all threads in a wave access the same page). + // SRD rebasing for K: only for page_size >= kN0 (all threads on same page). + // For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle + // addressing. auto rebase_k_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { @@ -649,24 +653,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const auto* page_ptr = base_ptr + static_cast(physical_page) * page_stride_k; window.set_bottom_tensor_view_data_ptr(page_ptr); + // Limit SRD num_records to one page worth of elements. + // Without this, the SRD claims validity for [page_ptr, page_ptr + + // full_buffer_size), which extends far beyond the allocated buffer when rebased to + // high pages. On gfx950, the hardware may validate the full SRD range against page + // table permissions, causing faults on freed/protected memory beyond the buffer. + window.set_bottom_tensor_view_buffer_size(page_stride_k); window.init_raw(); } }; + // SRD rebasing for V: only for page_size >= kN0 (all threads on same page). + // For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle + // addressing. auto rebase_v_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { + // readfirstlane: make physical_page provably wave-uniform so the + // resulting SRD lands in SGPRs (required by buffer load instructions). 
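+            // (readfirstlane broadcasts the first active lane's value into a scalar register;
+            // this is only valid here because kPageBlockSize >= kN0 guarantees every lane of
+            // the wave addresses the same physical page for this tile.)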
physical_page = __builtin_amdgcn_readfirstlane(physical_page); const auto* base_ptr = v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_; const auto* page_ptr = base_ptr + static_cast(physical_page) * page_stride_v; window.set_bottom_tensor_view_data_ptr(page_ptr); + window.set_bottom_tensor_view_buffer_size(page_stride_v); window.init_raw(); } }; - // Initial K SRD rebase + // Initial K SRD rebase (no-op for page_size < kN0, uses flat loads instead) rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]); constexpr auto k_oob_ck = bool_constant{}; @@ -874,12 +890,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, false, kN0, - kVectorSize>(v_physical_pages_k2, - stride_v, - page_stride_v, - v_coord, - v_offsets_k2, - current_seq_k); + kVectorSize, + kUseGlobalLoad>(v_physical_pages_k2, + stride_v, + page_stride_v, + v_coord, + v_offsets_k2, + current_seq_k); static_for<0, V_KIterInner, 1>{}([&](auto k1) { constexpr auto idx = number{}; @@ -899,9 +916,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, false, kN0, - kVectorSize>( + kVectorSize, + kUseGlobalLoad>( v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); } + + // v_offsets semantics — see the four-case addressing-strategy block above + // kNeedFullOffset in kv_offset_array_transform. Three cases reach this lambda: + // Case 1 (kPageBlockSize >= kN0): within-page offset; page base in SRD. + // Case 2 (page_size < kN0, kUseGlobalLoad): within-page offset; page base computed + // by tile_scatter_gather::load() from + // physical_pages_. + // Case 3 (page_size < kN0, !kUseGlobalLoad, == kNeedFullOffset): + // FULL offset (page * stride + within), + // carried in the 32-bit voffset (<2GB cap). }; // Prefetch V physical pages early to hide buffer load latency @@ -915,11 +943,32 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_offsets, number<1>{}, // HsGatherDim number<1>{}, // NumCoord - VPageIndexYDims); + VPageIndexYDims, + bool_constant{}, + page_stride_v); + if constexpr(kUseGlobalLoad) + { + v_dram_window.update_physical_pages(v_physical_pages); + } - // Initial V SRD rebase + // Initial V SRD rebase. Single source of truth: rebase_v_window's own + // `if constexpr(kPageBlockSize >= kN0)` makes this a no-op for case 2/3. + // Do not re-add an outer guard here — it would duplicate the inner check + // and drift if the lambda's gating condition ever changes. rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); + // Save the *current* tile's V physical pages into v_dram_window before + // prefetch_v_physical_pages overwrites the v_physical_pages buffer with the + // *next* tile's pages. Case-2 only (kUseGlobalLoad); case-1/3 don't read + // physical_pages_ from the window. Encapsulating the save+prefetch pair + // here makes the ordering invariant unmissable when a fourth prefetch site + // is added later. 
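+ // (Ordering sketch, restating the invariant above: if prefetch_v_physical_pages
+ // ran before update_physical_pages, v_physical_pages would already hold the *next*
+ // tile's pages when the window captured them, and case-2 gathers for the current
+ // tile would read V from the wrong physical pages.)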
+ auto save_and_prefetch_v_pages = [&](auto k_loop_start) { + if constexpr(kUseGlobalLoad) + v_dram_window.update_physical_pages(v_physical_pages); + prefetch_v_physical_pages(k_loop_start); + }; + // prefetch K tile async_load_tile_raw( k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); @@ -972,7 +1021,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } // Prefetch V physical pages early - overlaps with GEMM0 computation - prefetch_v_physical_pages(number{}); + save_and_prefetch_v_pages(number{}); // STAGE 1, QK gemm clear_tile(s_acc); // initialize C @@ -1166,7 +1215,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Prefetch V physical pages early - overlaps with softmax computation if constexpr(k1_loops > 1) { - prefetch_v_physical_pages(number<2 * kK1>{}); + save_and_prefetch_v_pages(number<2 * kK1>{}); } auto m_local = block_tile_reduce( @@ -1220,8 +1269,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_dram_window, {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... - v_buf = load_tile( - v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); update_v_offsets(number<2 * kK1>{}); v_dram_window.update_page_idx(v_offsets); rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); @@ -1390,8 +1438,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) { - v_buf = load_tile( - v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); // Update V offsets using previously prefetched physical pages update_v_offsets(number<(2 + i_k1.value) * kK1>{}); v_dram_window.update_page_idx(v_offsets); @@ -1401,7 +1448,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Prefetch V physical pages for NEXT iteration - overlaps with GEMM1 if constexpr(i_k1 + 1 < k1_loops - 1) { - prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{}); + save_and_prefetch_v_pages(number<(2 + i_k1.value + 1) * kK1>{}); } block_sync_lds(); @@ -1481,9 +1528,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, true, kN0, - kVectorSize>( + kVectorSize, + kUseGlobalLoad>( k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); k_dram_window.update_page_idx(k_offsets); + if constexpr(kUseGlobalLoad) + k_dram_window.update_physical_pages(k_physical_pages); rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]); // After sink→window transition (i_total_loops == num_sink_loop), V window diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 87db7b85b9..a8a8f96d3b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -9,6 +9,52 @@ namespace ck_tile { +namespace detail { + +template +CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize() +{ + if constexpr(std::is_same_v || std::is_same_v) + { + // ToDo: need support in ck_tile for using buffer_load_dwordx3 + // if constexpr(ElemPerThread % 6 == 0) + // return 6; + if constexpr(ElemPerThread % 8 == 0) + return 8; + else if constexpr(ElemPerThread % 4 == 0) + return 4; + else if constexpr(ElemPerThread % 2 == 0) + return 2; + return 1; + } + else if constexpr(std::is_same_v) + { + // ToDo: need support in 
ck_tile for using buffer_load_dwordx3 + // if constexpr(ElemPerThread % 3 == 0) + // return 3; + if constexpr(ElemPerThread % 4 == 0) + return 4; + else if constexpr(ElemPerThread % 2 == 0) + return 2; + return 1; + } + else + return 1; +}; + +template +CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize() +{ + constexpr index_t ElemPerThread = (kHigherDimSize * kLowerDimSize) / kThreadBlockSize; + + return GetMaxVectorSize(); +} + +} // namespace detail + template 2GB pools via + // 64-bit addressing; BUFFER_LOAD (default) uses SRD buffer_load for the + // <2GB fast path. The 2GB bound = INT32_MAX byte offset, matching CK's + // existing TwoGB convention. + static constexpr auto kKVLoadMode = Traits_::kKVLoadMode; + static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4 static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout; static constexpr auto kKVLookupTable = Traits_::kKVLookupTable; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index 8114bb96c4..607ee70020 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -20,7 +20,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch using KDataType = remove_cvref_t; using VDataType = remove_cvref_t; using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; + using CompDataType = remove_cvref_t; using BiasDataType = remove_cvref_t; using RandValOutputDataType = remove_cvref_t; using LSEDataType = remove_cvref_t; @@ -34,12 +34,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; static_assert(kQLoadOnce == Policy::QLoadOnce); + static_assert(!Problem::kUseTrLoad, "This pipeline does not use trload!"); + static_assert(sizeof(KDataType) == sizeof(VDataType) && + alignof(KDataType) == alignof(VDataType), + "K and V share the same LDS region; their element types must have identical " + "size and alignment."); + + static constexpr bool kUseN0Loop = true; + static constexpr bool kIgnoreFastExp2 = true; + static constexpr bool kIsNaiveHDimLoad = true; static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kN0Sub = + BlockFmhaShape::kK0; // subdivision of kN0 used in N0-loop, same value as kK0 static constexpr index_t kN1 = BlockFmhaShape::kN1; static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; @@ -47,35 +57,33 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? 
1 : Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + // since this pipeline is only used by the inference path of xformers, the Dropout function is + // not well tested with the pipeline, so here we have Dropout disabled + static_assert(kHasDropout == false, "Dropout is not supported by this pipeline at present!"); + // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); static constexpr index_t kAlignmentK = kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); - static constexpr index_t kAlignmentV = []() { - if constexpr(std::is_same_v) - return Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); - else - return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); - }(); + static constexpr index_t kAlignmentV = + Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentO = kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); - static constexpr index_t kAlignmentRandVal = - kPadSeqLenK ? 
1 : Policy::template GetAlignmentRandVal(); static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) @@ -135,9 +143,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch typename AttentionVariantParams, typename BlockIndices> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VElementFunction& v_element_func, @@ -158,8 +166,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch void* smem_ptr, DropoutType& dropout) const { - ignore = q_element_func; - ignore = k_element_func; + // xformers path does not require the pipeline to output random values for host + // verification, since a separate kernel is used to generate random values + ignore = randval_dram_block_window_tmp; static_assert( std::is_same_v> && @@ -168,8 +177,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch "wrong!"); static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN0Sub == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && @@ -177,24 +186,51 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch "wrong!"); constexpr auto I0 = number<0>{}; - constexpr auto I1 = number<1>{}; - constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t n0_loops = kN0 / kN0Sub; constexpr index_t k1_loops = kN0 / kK1; - static_assert(2 <= k0_loops); - static_assert(2 <= k1_loops); + + // usually kN0 is 128, kN0Sub/kK1 is 32/16 + static_assert(n0_loops >= 2, "n0_loops >= 2 required to use this pipeline"); + static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline"); + + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); + + constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV(); + static_assert(n0_loops >= NumPrefetchV, "Check failed!"); + static_assert(k1_loops >= NumPrefetchV, "Check failed!"); constexpr bool kPreloadWholeNextIterationK = Policy::template IsPreloadWholeNextIterationK(); - constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers(); - constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers(); - constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV(); + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - static_assert(NumKLdsBuffers >= 2); + // SaccBlockTile size is [kM0, kK1] + // PcompBlockTile size is [kM0, kN0] + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); + + SaccBlockTileType sacc_tile; + PcompBlockTileType pcomp_tile; + + // reduction function for softmax + const 
auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using MLBlockTileType = decltype(block_tile_reduce( + PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0})); + + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + OaccBlockTileType o_acc; auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - q_dram_block_window_tmp.get_window_lengths(), + make_tuple(number{}, number{}), q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQRegTileDistribution()); @@ -202,32 +238,38 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); - auto k_dram_block_window = - make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); + if(seqlen_k_end <= seqlen_k_start) + { + clear_tile(o_acc); + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + return o_acc; + }; auto k_dram_window = - make_tile_window(k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); using k_tile_type = decltype(load_tile(k_dram_window)); + // only prefetch two k tiles to save vgprs consumption auto k_tiles = [&]() { if constexpr(kPreloadWholeNextIterationK) - return statically_indexed_array{}; + return statically_indexed_array{}; else return statically_indexed_array{}; }(); k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); auto q_tile = load_tile(q_dram_window); - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(0x00000001); + + // provide partition_index for LDS tile window with so that warp_id is in vgpr + array partition_index{get_warp_id(), get_lane_id()}; // K tile in LDS KDataType* k_lds_ptr = static_cast(smem_ptr); @@ -236,612 +278,461 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch auto k_lds_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - using k_lds_window_type = - decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); + using k_lds_window_type = decltype(get_slice_tile( + k_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array k_lds_windows; + statically_indexed_array k_lds_windows; - static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) { - k_lds_windows[i_buf] = get_slice_tile( - k_lds_window, sequence{}, sequence<(i_buf + 1) * kN0, kK0>{}); + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + k_lds_windows[i_buf] = get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); }); - auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), - v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? 
- Policy::template MakeVDramTileDistribution()); // V tile in LDS auto v_lds = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr) + - Policy::template GetExclusiveKLdsBytes()), + reinterpret_cast(smem_ptr), Policy::template MakeVLdsBlockDescriptor()); auto v_lds_window = make_tile_window( v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); - using v_tile_type = decltype(load_tile(v_dram_window)); - - statically_indexed_array v_tiles; - using v_lds_window_type = decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array v_lds_windows; + statically_indexed_array v_lds_windows; - static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) { + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { v_lds_windows[i_buf] = get_slice_tile( v_lds_window, sequence{}, sequence<(i_buf + 1) * kN1, kK1>{}); }); - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {0, seqlen_k_start}, + Policy::template MakeVDramTileDistribution()); - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); - auto s_acc = SaccBlockTileType{}; - - // reduction function for softmax - const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; - const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; - - // infer Sacc, S, P, M, L, Oacc type - using SBlockTileType = decltype(cast_tile(s_acc)); - - using MLBlockTileType = decltype(block_tile_reduce( - SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); - - using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); - - // init Oacc, M, L - auto o_acc = OaccBlockTileType{}; - auto m = MLBlockTileType{}; - auto l = MLBlockTileType{}; + const auto f_exp = [&](CompDataType x) { + if constexpr(std::is_same_v) + { + return __expf(x); + } + else + { + return exp(x); + } + }; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); + set_tile(m, -numeric::infinity()); clear_tile(l); - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); - - // check early exit if no work to do - if constexpr(FmhaMask::IsMasking || kPadSeqLenK) - { - if(num_total_loop <= 0) - { - if constexpr(kStoreLSE) - { - auto lse = - make_static_distributed_tensor(m.get_tile_distribution()); - - set_tile(lse, -numeric::infinity()); - - store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); - } - - // Note: here occ are all cleard, return it - // Note: q loaded but no fence, ignore it. 
- return o_acc; - } - } - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N - Policy::template MakeBiasDramTileDistribution()); + make_tuple(number{}, number{}), + {bias_origin.at(number<0>{}), seqlen_k_start}, + Policy::template MakeBiasDramTileDistribution()); - auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); + // assuming no random values need be saved, this is true when the pipeline is called from + // xformers, since we have a separate kernel to generated random values + auto null_randval_window = [&]() { + if constexpr(kHasDropout) + { + // need to pass a null_randval_dram and tile window to the BlockDropout operator to + // make it works + const auto null_randval_dram = [&]() { + const auto null_dram_naive = make_naive_tensor_view( + static_cast(nullptr), + make_tuple(1, 1), + make_tuple(1, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(null_dram_naive, + make_tuple(number<1>{}, number<1>{}), + sequence{}); + }(); + + return make_tile_window( + null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); + } + else + return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); + }(); q_tile = tile_elementwise_in(q_element_func, q_tile); - index_t i_total_loops = 0; + auto seqlen_k_curr = seqlen_k_start; + + using v_tile_type = decltype(load_tile(v_dram_window)); + + statically_indexed_array v_tiles; do { - if constexpr(kPreloadWholeNextIterationK) + // STAGE 1, Gemm_0 ( S = Q@K ) + if constexpr(kPreloadWholeNextIterationK) // used when kM0 = 64 { - if(i_total_loops == 0) // executed by fist iteration + if(seqlen_k_curr == seqlen_k_start) // at first iteration { - if(num_total_loop > 1) // there are multiple iterations + if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - store_tile( - k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); - k_tiles[number{}] = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 2) - move_tile_window(k_dram_window, {0, kK0}); + if constexpr(i_n0 < n0_loops - 1) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; - if constexpr(i_k0 == 0) - clear_tile(s_acc); + if constexpr(i_n0 == n0_loops - 1) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + + // prefetch all k_tiles for next iteration + static_for<0, n0_loops, 1>{}([&](auto ii_n0) { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }); + }; block_sync_lds(); - // execute current unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); - - store_tile( - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], - tile_elementwise_in(k_element_func, 
k_tiles[number{}])); - - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0}); - - // prefetch all k_tiles for next iteration - static_for<0, k0_loops, 1>{}([&](auto i_k0) { - k_tiles[number{}] = load_tile(k_dram_window); - - if constexpr(i_k0 < k0_loops - 1) - move_tile_window(k_dram_window, {0, kK0}); - }); - - move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); - - block_sync_lds(); - // execute last unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); } - else // there is only single iteration + else // the iteration is also the last iteration { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - store_tile( - k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); - k_tiles[number{}] = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 2) - move_tile_window(k_dram_window, {0, kK0}); + if constexpr(i_n0 < n0_loops - 1) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; - if constexpr(i_k0 == 0) - clear_tile(s_acc); + if constexpr(i_n0 == n0_loops - 1) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; block_sync_lds(); - // execute current unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); - - store_tile( - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); - - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); - - // move_tile_window(k_dram_window, {0, -k0_loops * kK0}); - } + }; } - else // executed by intermediate and last iteration + else // at intermediate and last iteration { - if(i_total_loops < num_total_loop - 1) // intermediate iteration + if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration { - store_tile(k_lds_windows[I0], - tile_elementwise_in(k_element_func, k_tiles[I0])); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); + if constexpr(i_n0 == 0) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; - clear_tile(s_acc); - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, sequence<0, 0>{}, sequence{}), - k_lds_windows[I0]); - - store_tile(k_lds_windows[I1], - tile_elementwise_in(k_element_func, k_tiles[I1])); - - move_tile_window(k_dram_window, {kN0, 0}); - - // prefetch first k_tile for next iteration - 
k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - - k_tiles[I1] = load_tile(k_dram_window); - if constexpr(1 < k0_loops - 1) - move_tile_window(k_dram_window, {0, kK0}); - - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, sequence<0, kK0>{}, sequence{}), - k_lds_windows[I1]); - - // during the gemm-loop, also prefetch other k_tiles for next iteration - static_for<2, k0_loops, 1>{}([&](auto i_k0) { - store_tile(k_lds_windows[number{}], - k_tiles[number{}]); - - k_tiles[number{}] = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 1) - move_tile_window(k_dram_window, {0, kK0}); + // prefetch k_tile for next iteration + k_tiles[i_n0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); - }); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); - move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); } else // last iteration { - store_tile(k_lds_windows[I0], - tile_elementwise_in(k_element_func, k_tiles[I0])); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - clear_tile(s_acc); - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, sequence<0, 0>{}, sequence{}), - k_lds_windows[I0]); - - static_for<1, k0_loops, 1>{}([&](auto i_k0) { - store_tile( - k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); + if constexpr(i_n0 == 0) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); }; - }; + } } - else // only preload one unroll of K for next iteration + else // only preload one unroll of K for next iteration, used when kM0=128 { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - store_tile(k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[I0])); - if constexpr(i_k0 == 0) - clear_tile(s_acc); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[I0]), + partition_index); - if constexpr(i_k0 < k0_loops - 1) + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(i_n0 < n0_loops - 1) + { k_tiles[I0] = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 2) - move_tile_window(k_dram_window, {0, kK0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + if constexpr(i_n0 == n0_loops - 1) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; + + __builtin_amdgcn_sched_barrier(0x00000001); block_sync_lds(); - // execute current unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - 
k_lds_windows[number{}]); + + gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); + } - store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], - tile_elementwise_in(k_element_func, k_tiles[I0])); - - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); - }; - - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(0x00000001); const auto bias_tile = load_tile(bias_dram_window); // load bias tile - static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) { - v_tiles[i_buf] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - }); - // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + tile_elementwise_inout( - [&](auto& x, const auto& y) { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - x += type_convert(bias_element_func(y)); -#else - x += log2e_v * - type_convert(bias_element_func(y)); -#endif + [&](auto& x, const auto y) { + x += type_convert(bias_element_func(y)); }, - s_acc, + pcomp_tile, bias_tile); } else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) { - const auto k_origin = k_dram_block_window.get_window_origin(); - constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + constexpr auto pcomp_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(pcomp_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(pcomp_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( - s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1); - s_acc(i_j_idx) *= scale_s; - position_encoding.update(s_acc(i_j_idx), row, col); + pcomp_tile(i_j_idx) *= scale_s; + position_encoding.update(pcomp_tile(i_j_idx), row, col); }); }); } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); -#endif + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { - const auto k_origin = k_dram_block_window.get_window_origin(); - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}), - number{}, - number{}); + bool need_perpixel_check = mask.IsEdgeTile( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{}); if(need_perpixel_check) { - set_tile_if( 
- s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); - }); + set_tile_if(pcomp_tile, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); } } - const auto s = cast_tile(s_acc); // S{j} - auto m_local = block_tile_reduce( - s, - sequence<1>{}, - f_max, - -numeric::infinity()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max, bool_constant{}); + __builtin_amdgcn_sched_barrier(0x00000001); - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + const auto m_old = m; - auto p_compute = make_static_distributed_tensor( - s.get_tile_distribution()); // Pcompute{j} - - static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_m == -numeric::infinity() - ? type_convert(0.f) - : raw_m; - } - else - { - return raw_m; - } - }; - - constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); -#endif - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); - } -#else - p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); -#endif - }); - }); - - auto rowsum_p = block_tile_reduce( - p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - - block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); - } - }(); -#else - const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); -#endif - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - // FIXME: this use different equation from FA v2 paper, - // but produce correc result. - // Is the equation wrong? 
- o_acc(i_j_idx) *= tmp; - }); - }); - - if constexpr(kHasDropout) - { - auto randval_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeK(); - dropout.template Run( - smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); - } - - __builtin_amdgcn_sched_barrier(0x7f); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_tiles[I0]); - - store_tile( - v_lds_windows[I0], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - store_tile(v_lds_windows[I0], - tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch - } + block_tile_reduce(m, pcomp_tile, sequence<1>{}, f_max); + block_tile_reduce_sync(m, f_max, bool_constant{}); __builtin_amdgcn_sched_barrier(0); - const auto p = - cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + auto v_shuffled_tile = make_static_distributed_tensor( + Policy::template MakeShuffledVRegTileDistribution()); + shuffle_tile(v_shuffled_tile, tile_elementwise_in(v_element_func, v_tiles[I0])); - if constexpr(!kPreloadWholeNextIterationK) + // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) { - if(i_total_loops < num_total_loop - 1) - { - move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0}); - k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - }; - - __builtin_amdgcn_sched_barrier(0); - } - - // STAGE 3, KV gemm - if constexpr(k1_loops > 1) - { - if constexpr(NumPrefetchV == 1) // NumVLdsBuffers == 2 - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - v_tiles[I0] = load_tile(v_dram_window); - - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_windows[number{}]); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_tiles[I0]); - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); - } - else - { - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_tiles[I0])); - } - - move_tile_window(v_dram_window, {0, kK1}); - }); - } - else // NumVLdsBuffers == 3 or 2 - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - if constexpr(i_k1 < k1_loops - NumPrefetchV) - v_tiles[number{}] = load_tile(v_dram_window); - - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_windows[number{}]); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, - v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); - } - else - { - store_tile( - v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, - v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); - } - - if constexpr(i_k1 < k1_loops - NumPrefetchV) - move_tile_window(v_dram_window, {0, kK1}); - }); - } - } - // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); - - 
block_sync_lds(); - gemm_1(o_acc, - get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), - v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]); - - if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer()) - { - __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); }; - } while(++i_total_loops < num_total_loop); + store_tile( + v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index); + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) { + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }); + + __builtin_amdgcn_sched_barrier(0); + + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = type_convert(0.0f); + }); + } + else + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); + }); + } + }); + + auto rowsum_p = + block_tile_reduce(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + + // adjust o_acc[] according to the update between m and m_old + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + l(i_idx) = rowsum_p[i_idx]; + } + else + { + const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + } + }); + + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(kHasDropout) + { + auto randval_lds_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + + dropout.template Run( + randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + } + + seqlen_k_curr += kN0; + + __builtin_amdgcn_sched_barrier(0x00000001); + + auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // STAGE 3, Gemm_1 ( O = P@V ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + if constexpr(i_k1 < k1_loops - NumPrefetchV) + { + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; + + if constexpr(i_k1 == k1_loops - NumPrefetchV) + { + if constexpr(!kPreloadWholeNextIterationK) + { + if(seqlen_k_curr < seqlen_k_end) + { + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + } + }; + + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); + + if constexpr(i_k1 < k1_loops - 1) + { + shuffle_tile(v_shuffled_tile, + tile_elementwise_in(v_element_func, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); + store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], + v_shuffled_tile, + partition_index); + }; + }); + + // check whether last V-LdsBuffer overlap with first K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops 
- 1 + 2) % NumKVLdsBuffers == 0) + { + __builtin_amdgcn_s_barrier(); + }; + } while(seqlen_k_curr < seqlen_k_end); // store lse if constexpr(kStoreLSE) @@ -851,19 +742,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); - } - else - { - lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); - } -#else - lse(i_idx) = m_[i_idx] + log(l_[i_idx]); -#endif + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); }); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); @@ -874,17 +753,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - if constexpr(FmhaMask::IsMasking) - { - return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; - } - else - return 1 / l[i_idx]; - }(); sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; + + if(m[i_idx] == -numeric::infinity()) + o_acc(i_j_idx) = 0.0f; + else + o_acc(i_j_idx) *= 1.0f / l[i_idx]; }); }); @@ -916,8 +791,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { + ignore = sink_v; + return operator()(q_dram_block_window_tmp, identity{}, k_dram_block_window_tmp, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index 3f015a1c1a..e5e9e2333a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -4,17 +4,20 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp" namespace ck_tile { struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy - : BlockFmhaPipelineQXKSVSCustomPolicy { - static constexpr index_t NumPrefetchV = 2; + static constexpr bool QLoadOnce = true; // needed by the kernel + static constexpr bool AsyncCopy = false; // needed by the kernel template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsPreloadWholeNextIterationK() @@ -23,30 +26,38 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy }; template - CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumPrefetchV() { - return 2; - } + constexpr index_t n0_loops = 
Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kK0; + constexpr index_t k1_loops = Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kK1; - template - CK_TILE_DEVICE static constexpr auto GetNumPrefetchV() - { - using BlockFmhaShape = remove_cvref_t; - - constexpr index_t kN0 = BlockFmhaShape::kN0; - constexpr index_t kK1 = BlockFmhaShape::kK1; - - constexpr index_t k1_loops = kN0 / kK1; - - return min(NumPrefetchV, k1_loops); - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumVLdsBuffers() - { - return 2; + if constexpr(Problem::kUseTrLoad) + { + // kM0 is 64, kN0 is 128, prefetch all k_tiles + if constexpr(IsPreloadWholeNextIterationK()) + { + if constexpr(n0_loops >= 4 && k1_loops >= 6) + return 2; + return 2; + } + else // kM0 is 128, kN0 is 64, prefetch one k_tile + { + // kN0 == 64, try to prefetch more v_tiles + return 2; + }; + } + else + { + return 2; + }; }; + template + CK_TILE_HOST_DEVICE static constexpr auto GetNumKVLdsBuffers() + { + return 4; + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() { @@ -57,195 +68,537 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy Problem::BlockFmhaShape::kQKHeaddim>(); } + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetQKWarpGemmKPerThreadSize() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::kKPerThread; + }; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetKVWarpGemmKPerThreadSize() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::kKPerThread; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + return BlockGemm::template MakeCBlockTile() + .get_tile_distribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() { - using KDataType = remove_cvref_t; - return 8 / sizeof(KDataType); + if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + return 8; + else + return 4; } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + + return detail:: + GetDramTileAccessMaxVectorSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) + return 8; + else + return 4; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + 
using VDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + // special consideration when shuffling is required before storing V to LDS + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t kMaxVecLoad = detail:: + GetDramTileAccessMaxVectorSize(); + constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); + + // try to avoid writing sub-dword to LDS due to poor performance + constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad) + ? kMaxVecLoad + : (ElemPerThread / kMinVecLoad); + + return kVecLoad; + } + else + { + return detail:: + GetDramTileAccessMaxVectorSize(); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kKVector = GetAlignmentK(); + + // for hdim96 and hdim160 + if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim) + { + return kKPerBlock * kNPerBlock; + } + else if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + { + static_assert(kKVector == kKPack); + + return kKPerBlock * kNPerBlock; + } + else + { + static_assert(kKVector % kKPack == 0); + + return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + }; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize(); + + return N0 * (N1 * kKPerBlock + kKPack); + } + else + { + return kNPerBlock * kKPerBlock; + }; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + return max(GetKSingleSmemElementSpaceSize(), + GetVSingleSmemElementSpaceSize()); + }; + template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { - constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers(); - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); - static_assert(kKVector % kKPack == 0); + // for hdim96 and hdim160, use simplest layout + if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim) + { + constexpr index_t KSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock; - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number{}, - number{}, - number{}, - number<1>{}), - 
number{}, - number<1>{}); + static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); - constexpr auto k_lds_block_desc = transform_tensor_descriptor( - k_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, - number{}, - number{}))), - make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); - return k_lds_block_desc; + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + else if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + { + static_assert(kKVector == kKPack); + + using KDataType = remove_cvref_t; + + constexpr index_t DataTypeSize = sizeof(KDataType); + +#ifdef __gfx950__ + // 256 contiguous bytes mapped to 64 banks with each bank 4 contiguous bytes + constexpr auto NLdsLayer = + (64 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (64 * 4 / kKPerBlock / DataTypeSize); +#else + // 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); +#endif + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + constexpr auto k_lds_block_desc_k0_nldslayer_n_k1 = transform_tensor_descriptor( + k_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{})); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_k0_nldslayer_n_k1, + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<1, 3>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + else + { + static_assert(kKVector % kKPack == 0); + + constexpr index_t KSingleSmemElementSpaceSize = + kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + + static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + + constexpr auto k_lds_block_desc_0 = + 
make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + }; } template CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() { - using KDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); + constexpr index_t kKVector = GetAlignmentK(); + constexpr index_t OtherK = kKPerBlock / kKVector; - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - - constexpr index_t KPerThread = kMaxVecLoad; - constexpr index_t KThreads = kKPerBlock / KPerThread; - constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() - { - using VDataType = remove_cvref_t; - - constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers(); - - constexpr index_t Banks = get_n_lds_banks(); - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(PixelsPerRow % kKPack == 0); - constexpr index_t NPerRow = PixelsPerRow / kKPack; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - static_assert(kNPerBlock % NPerRow == 0); - static_assert(kKPerBlock % kKPack == 0); - - constexpr index_t VSingleSmemElementSpaceSize = - (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); - - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); - - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple( - number{}, number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return v_lds_block_desc; - } - - template - CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() - { - using VLayout = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = 
Problem::BlockFmhaShape::kK1; - - if constexpr(std::is_same_v) + if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim) + // for kKPerBlock=32,64,128,256 { - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; // P + static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(ElemPerThread % N1 == 0); - constexpr index_t K3 = ElemPerThread / N1; - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; - if constexpr(get_warp_size() % (K2 * N0) == 0) - { - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - static_assert(kKPerBlock == K0 * K1 * K2 * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - else - { - constexpr index_t K1 = (K2 * N0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = kBlockSize / get_warp_size() / K1; - static_assert(kKPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - } - else - { - constexpr index_t K1 = GetAlignmentV(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); - static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - static_assert(N0 != 0); + constexpr index_t KPerThread = kKVector; + constexpr index_t KThreads = OtherK; + + constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, + sequence>, tuple, sequence<1, 2>>, tuple, sequence<2, 0>>, sequence<1, 2>, sequence<0, 1>>{}); } + else // for kKPerBlock=96,160 + { + static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); + + constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 
3 : 5; + constexpr index_t KThreads = OtherK / KRepPerThread; + + constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers(); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + // K2 is the vector size for storing shuffled tile to LDS + constexpr index_t K2 = ElemPerThread / N1; + + // GetSmemKPackV() is the vector size for loading from LDS by BlockGemm + constexpr index_t kKPack = GetSmemKPackV(); + + static_assert(kKPack >= K2, "Check failed!"); + + constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack); + + static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple( + number{}, number{}, number{}, number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + else + { + constexpr index_t kKPack = GetSmemKPackV(); + + constexpr auto XorGroupSize = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + + constexpr index_t VSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock; + + static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + + constexpr auto v_lds_block_desc_naive = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor( + v_lds_block_desc_naive, + make_tuple(make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + v_lds_block_desc_permuted, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + }; + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = 
Problem::BlockFmhaShape::kK1; + + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NThreads = kNPerBlock / NPerThread; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t KPerThread = ElemPerThread / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<2, 1>, + sequence<2, 1>>{}); + } + else + { + constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NThreads = kNPerBlock / NPerThread; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t KPerThread = ElemPerThread / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NThreads = kNPerBlock / NPerThread; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t KPerThread = ElemPerThread / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<1, 2>>{}); } template @@ -257,113 +610,167 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy typename Problem::SaccDataType, Problem::kNumGemm0Warps * get_warp_size(), TileGemmShape, + Problem::BlockFmhaShape::kK0, + Problem::BlockFmhaShape::kQKHeaddim>, typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; - constexpr auto warp_gemm = []() { - if constexpr(get_warp_size() == 64 && - std::is_same_v && - std::is_same_v && + auto warp_gemm = [&]() { + if constexpr((std::is_same_v || + std::is_same_v) && std::is_same_v) { - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32); - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32); - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32); + constexpr index_t WarpGemmM = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}); + +#ifdef __gfx950__ + static_assert((WarpGemmM == 16 && WarpGemmK == 32) || + (WarpGemmM == 32 && WarpGemmK == 16), + "Not supported WarpGemm sizes!"); +#else + static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) || + (WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)), + "Not supported WarpGemm sizes!"); +#endif - // TODO: hard coded here. 
Otherwise, it produces incorrect results - constexpr index_t swizzle_factor = 4; - return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution< - swizzle_factor>{}; - } - else - { - constexpr bool SwizzleA = - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32; return WarpGemmDispatcher{}), Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), - true, // TransposeC - SwizzleA>{}; + true, + false, + false, + WGAttrNumAccessEnum::Single>{}; + } + else + { + static_assert(false, "Not supported data types!"); } }(); + using WarpGemm = remove_cvref_t; + using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy; + WarpGemm>; if constexpr(1 < Problem::kNumGemm0Warps) - return BlockGemmARegBSmemCRegV2{}; + return BlockGemmARegBSmemCRegV2PrefetchK{}; else return BlockGemmARegBSmemCRegOneWarpV1{}; } - // leave some exclusive space so that the second v_lds buffer will nenver overlap with the first - // k_lds bufffer template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes() + CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() { - constexpr index_t single_k_lds_buffer_size = - GetSmemSizeK() / GetNumKLdsBuffers(); - constexpr index_t single_v_lds_buffer_size = - GetSmemSizeV() / GetNumVLdsBuffers(); + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; - if constexpr(single_k_lds_buffer_size <= single_v_lds_buffer_size) - return 0; + auto warp_gemm = [&]() { + if constexpr((std::is_same_v || + std::is_same_v) && + std::is_same_v) + { + constexpr index_t WarpGemmM = + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}); + + static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) || + (WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)), + "Not supported WarpGemm sizes!"); + + if constexpr((WarpGemmM == 16 && WarpGemmK == 32) || + (WarpGemmM == 32 && WarpGemmK == 16)) + return WarpGemmDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Double>{}; + else + return WarpGemmDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Single>{}; + } + else + { + static_assert(false, "Not supported data types!"); + } + }(); + + using WarpGemm = remove_cvref_t; + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + + if constexpr(1 < Problem::kNumGemm1Warps) + { + if constexpr(!Problem::kUseTrLoad) + return BlockGemmARegBSmemCRegV2PrefetchN{}; + else + return BlockGemmARegBSmemTrLoadCRegV2PrefetchN{}; + } else - return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64); - }; - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer() - { - using BlockFmhaShape = remove_cvref_t; - - constexpr index_t k1_loops = BlockFmhaShape::kN0 / 
BlockFmhaShape::kK1; - constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers(); - constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers(); - - constexpr index_t last_v_lds_buffer_offset = - MakeVLdsBlockDescriptor().get_element_space_size() / num_v_lds_buffers * - ((k1_loops - 1) % num_v_lds_buffers) * sizeof(typename Problem::VDataType); - - constexpr index_t first_k_lds_buffer_size = - MakeKLdsBlockDescriptor().get_element_space_size() / num_k_lds_buffers * - sizeof(typename Problem::KDataType); - - return GetExclusiveKLdsBytes() + last_v_lds_buffer_offset < - first_k_lds_buffer_size; - }; - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() - { - return MakeKLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::KDataType); + return BlockGemmARegBSmemCRegOneWarpV1{}; } template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { - return MakeVLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::VDataType); - } + constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers(); + + return num_kv_lds_buffers * GetSingleSmemElementSpaceSize() * + max(sizeof(typename Problem::KDataType), sizeof(typename Problem::VDataType)); + }; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout() + { + static_assert(!Problem::kHasDropout, + "BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy does not " + "account for dropout LDS scratch space. Either use a policy " + "that implements dropout shared-memory sizing or disable dropout " + "for this pipeline."); + return 0; + }; template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - // assume V can reuse the other shared memory by K except the first - // assume Dropout can reuse the shared memory by V - return GetExclusiveKLdsBytes() + - max(GetSmemSizeK() - GetExclusiveKLdsBytes(), - max(GetSmemSizeV(), GetSmemSizeDropout(0))); + return GetSmemSizeKV() + GetSmemSizeDropout(); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp new file mode 100644 index 0000000000..95f68623fa --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp @@ -0,0 +1,861 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using CompDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; + static_assert(kQLoadOnce == Policy::QLoadOnce); + static_assert(sizeof(KDataType) == sizeof(VDataType) && + alignof(KDataType) == alignof(VDataType), + "K and V share the same LDS region; their element types must have identical " + "size and alignment."); + + static constexpr bool kUseN0Loop = true; + static constexpr bool kIgnoreFastExp2 = true; + static constexpr bool kIsNaiveHDimLoad = true; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kN0Sub = + BlockFmhaShape::kK0; // subdivision of kN0 used in N0-loop, same value as kK0 + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + + static_assert(Problem::kUseTrLoad == true, "Check failed!"); + + static constexpr bool kUseTrLoad = true; + + // since this pipeline is only used by the inference path of xformers, the Dropout function is + // not well tested with the pipeline, so here we have Dropout disabled + static_assert(kHasDropout == false, "Dropout is not supported by this pipeline at present!"); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + Problem::kPadHeadDimV ? 
1 : Policy::template GetAlignmentV(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kQKHeaddim == 32) + { + return 2; + } + else if constexpr(kQKHeaddim == 64) + { + return 2; + } + else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim == 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_async_whole_k_prefetch_trload"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& /* unused */, + const AttentionVariantParams& /* unused */, + const BlockIndices& /* unused */, + void* smem_ptr, + DropoutType& dropout) const + { + // xformers path does not require the pipeline to output random values for host + // verification, since a separate kernel is used to generate random values + ignore = randval_dram_block_window_tmp; + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0Sub == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + constexpr index_t n0_loops = kN0 / kN0Sub; + constexpr index_t k1_loops = kN0 / kK1; + + // usually kN0 is 128, kN0Sub/kK1 is 32/16 + static_assert(n0_loops >= 2, "n0_loops >= 2 required to use this pipeline"); + static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline"); + + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); + + constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV(); + static_assert(n0_loops >= NumPrefetchV, "Check failed!"); + static_assert(k1_loops >= NumPrefetchV, "Check failed!"); + + constexpr bool 
kPreloadWholeNextIterationK = + Policy::template IsPreloadWholeNextIterationK(); + + // This path prefetches two k_tiles for next iteration, so it has the opportunity to + // prefetch two v_tiles during Gemm0 + if constexpr(!kPreloadWholeNextIterationK) + { + static_assert(NumPrefetchV >= 2); + }; + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + // SaccBlockTile size is [kM0, kK1] + // PcompBlockTile size is [kM0, kN0] + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); + + SaccBlockTileType sacc_tile; + PcompBlockTileType pcomp_tile; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using MLBlockTileType = decltype(block_tile_reduce( + PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0})); + + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + OaccBlockTileType o_acc; + + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + if(seqlen_k_end <= seqlen_k_start) + { + clear_tile(o_acc); + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + return o_acc; + }; + + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); + + auto q_tile = load_tile(q_dram_window); + + using k_tile_type = decltype(load_tile(k_dram_window)); + + auto k_tiles = [&]() { + if constexpr(kPreloadWholeNextIterationK) + return statically_indexed_array{}; + else + return statically_indexed_array{}; + }(); + + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + + if constexpr(!kPreloadWholeNextIterationK) + { + k_tiles[I1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + __builtin_amdgcn_sched_barrier(0x00000001); + + // provide partition_index for LDS tile window with so that warp_id is in vgpr + array partition_index{get_warp_id(), get_lane_id()}; + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(smem_ptr); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = make_tile_window( + k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); + + using k_lds_window_type = decltype(get_slice_tile( + k_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array k_lds_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + k_lds_windows[i_buf] = get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); + }); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + using 
v_lds_window_type = + decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array v_lds_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + v_lds_windows[i_buf] = get_slice_tile( + v_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kN1>{}); + }); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, + Policy::template MakeVDramTileDistribution()); + + const auto f_exp = [&](CompDataType x) { + if constexpr(std::is_same_v) + { + return __expf(x); + } + else + { + return exp(x); + } + }; + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {bias_origin.at(number<0>{}), seqlen_k_start}, + Policy::template MakeBiasDramTileDistribution()); + + // assuming no random values need be saved, this is true when the pipeline is called from + // xformers, since we have a separate kernel to generated random values + auto null_randval_window = [&]() { + if constexpr(kHasDropout) + { + // need to pass a null_randval_dram and tile window to the BlockDropout operator to + // make it works + const auto null_randval_dram = [&]() { + const auto null_dram_naive = make_naive_tensor_view( + static_cast(nullptr), + make_tuple(1, 1), + make_tuple(1, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(null_dram_naive, + make_tuple(number<1>{}, number<1>{}), + sequence{}); + }(); + + return make_tile_window( + null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); + } + else + return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); + }(); + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + q_tile = tile_elementwise_in(q_element_func, q_tile); + + auto seqlen_k_curr = seqlen_k_start; + + using v_tile_type = decltype(load_tile(v_dram_window)); + + statically_indexed_array v_tiles; + + do + { + // STAGE 1, Gemm_0 ( S = Q@K ) + if constexpr(kPreloadWholeNextIterationK) // used when kM0 = 64 + { + if(seqlen_k_curr == seqlen_k_start) // at first iteration + { + if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + if constexpr(i_n0 < n0_loops - 1) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + if constexpr(i_n0 == n0_loops - 1) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + + // prefetch all k_tiles for next iteration + static_for<0, n0_loops, 1>{}([&](auto ii_n0) { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }); + }; + + block_sync_lds(); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + } + else // the iteration is also the last iteration + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + if constexpr(i_n0 < n0_loops - 1) + { + k_tiles[number{}] = load_tile(k_dram_window); + 
move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + if constexpr(i_n0 == n0_loops - 1) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + block_sync_lds(); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + }; + } + else // at intermediate and last iteration + { + if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + if constexpr(i_n0 == 0) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + // prefetch k_tile for next iteration + k_tiles[i_n0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + + block_sync_lds(); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + } + else // last iteration + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + if constexpr(i_n0 == 0) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + block_sync_lds(); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + }; + } + } + else // only preload one unroll of K for next iteration, used when kM0=128 + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(i_n0 < n0_loops - 2) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + if constexpr(i_n0 >= n0_loops - 2) + { + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + __builtin_amdgcn_sched_barrier(0x00000001); + + block_sync_lds(); + + gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + } + + __builtin_amdgcn_sched_barrier(0x000000001); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + + tile_elementwise_inout( + [&](auto& x, const auto y) { + x += type_convert(bias_element_func(y)); + }, + pcomp_tile, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + constexpr auto pcomp_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(pcomp_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(pcomp_spans[number<1>{}], [&](auto idx1) { 
+ const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + pcomp_tile(i_j_idx) *= scale_s; + position_encoding.update(pcomp_tile(i_j_idx), row, col); + }); + }); + } + else + { + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + } + + move_tile_window(bias_dram_window, {0, kN0}); + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + bool need_perpixel_check = mask.IsEdgeTile( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{}); + if(need_perpixel_check) + { + set_tile_if(pcomp_tile, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + __builtin_amdgcn_sched_barrier(0x00000001); + + auto m_local = block_tile_reduce( + pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; + + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + + __builtin_amdgcn_sched_barrier(0); + + // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + + store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_tiles[I0]), + partition_index); + + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(kPreloadWholeNextIterationK) + { + static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) { + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }); + } + else + { + static_for<2, NumPrefetchV, 1>{}([&](auto i_k1) { + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }); + }; + + __builtin_amdgcn_sched_barrier(0); + + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = type_convert(0.0f); + }); + } + else + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); + }); + } + }); + + auto rowsum_p = + block_tile_reduce(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + + // adjust o_acc[] according to the update between m and m_old + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + l(i_idx) = rowsum_p[i_idx]; + } + else + { + const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + } + }); + + 
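+        // Reference-only scalar sketch of the online-softmax rescaling performed
+        // above, per row of the P/O tiles (the names m_new / scale below are
+        // illustrative and are not identifiers in this pipeline):
+        //
+        //   m_new  = max(m_old, rowmax(S));                  // running row maximum
+        //   scale  = exp(m_old - m_new);                     // correction for the new maximum
+        //   l_new  = scale * l_old + rowsum(exp(S - m_new)); // running row sum
+        //   o_row *= scale;                                  // keep O consistent with l_new
+        //
+        // The m == -infinity branch above only seeds l with rowsum_p and skips the
+        // rescaling of o_acc.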
__builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(kHasDropout) + { + auto randval_lds_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + + dropout.template Run( + randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + } + + seqlen_k_curr += kN0; + + __builtin_amdgcn_sched_barrier(0x00000001); + + auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // STAGE 3, Gemm_1 ( O = P@V ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + if constexpr(i_k1 < k1_loops - NumPrefetchV) + { + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + if constexpr(i_k1 == k1_loops - NumPrefetchV) + { + if constexpr(!kPreloadWholeNextIterationK) + { + if(seqlen_k_curr < seqlen_k_end) + { + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + } + }; + + if constexpr(i_k1 == k1_loops - NumPrefetchV + 1) + { + if constexpr(!kPreloadWholeNextIterationK) + { + if(seqlen_k_curr < seqlen_k_end) + { + k_tiles[I1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + } + }; + + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); + + if constexpr(i_k1 < k1_loops - 1) + { + store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]), + partition_index); + }; + }); + + // check whether last V-LdsBuffer overlap with first K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0) + { + __builtin_amdgcn_s_barrier(); + }; + } while(seqlen_k_curr < seqlen_k_end); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(m[i_idx] == -numeric::infinity()) + o_acc(i_j_idx) = 0.0f; + else + o_acc(i_j_idx) *= 1.0f / l[i_idx]; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + DropoutType& dropout, + const float sink_v) const + { + ignore = 
sink_v; + + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index d2d8bb2c7e..9fc3652f51 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -692,8 +692,11 @@ struct BlockFmhaPipelineQSKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { + ignore = sink_v; + return operator()(q_dram_block_window_tmp, identity{}, k_dram_block_window_tmp, diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index 71da3767b0..f217f57bad 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -57,7 +57,7 @@ struct TileFmhaShape static constexpr index_t kQKHeaddim = BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at // once (or repeately load Q as a whole tile) - static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0"); + static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim must be divisible by kK0!"); static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(); diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 7df39c3d11..e7370cdb65 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" @@ -58,7 +59,9 @@ template + BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D, + BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ = + BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD> struct TileFmhaBatchPrefillTraits : public TileFmhaTraits= 0) { if constexpr(Problem::LocalToken) { diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp new file mode 100644 index 0000000000..f84d232196 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
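+//
+// Loop-structure sketch of the K-prefetch block GEMM implemented below
+// (orientation only; the pseudocode names are descriptive, not identifiers
+// from this file):
+//
+//   for nIter in [0, NIterPerWarp):
+//       load B[nIter][0]                                      // first K slice of this N tile
+//       for kIter in [0, KIterPerWarp):
+//           if kIter + 1 < KIterPerWarp:
+//               load B[nIter][kIter + 1]                      // prefetch the next K slice
+//           for mIter in [0, MIterPerWarp):
+//               C[mIter][nIter] (+)= A[mIter][kIter] * B[nIter][kIter]
+//                                                             // kIter == 0 initializes C,
+//                                                             // later kIter accumulate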
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegV2PrefetchK +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + static_assert(NWarp == 1, "Check failed!"); + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = 
uniform_sequence_gen_t{}; + + constexpr auto I0 = number<0>{}; + + // hot loop: + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0))); + + statically_indexed_array b_warp_tensors; + + // read B warp tensor from B Block window + b_warp_windows(nIter)(I0) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(I0), + {nIter * NPerBlockPerIter, 0 * KPerBlockPerIter}); + b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0)); + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + if constexpr(kIter < KIterPerWarp - 1) + { + // read B warp tensor from B Block window + b_warp_windows(nIter)(number{}) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(number{}), + {nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter}); + b_warp_tensors[number{}] = + load_tile(b_warp_windows(nIter)(number{})); + }; + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + if constexpr(kIter == 0) + { + // warp GEMM + c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[kIter]); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + } + else + { + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + }; + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode(); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + template + CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr 
index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + + static_assert(NWarp == 1, "Check failed!"); + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + return c_block_dstr_encode; + } + + template + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode(); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp new file mode 100644 index 0000000000..51f59e16c0 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegV2PrefetchN +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 
0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + constexpr auto I0 = number<0>{}; + + using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0))); + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + statically_indexed_array b_warp_tensors; + + // read B warp tensor from B Block window + b_warp_windows(I0)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(I0)(kIter), + {0 * NPerBlockPerIter, kIter * KPerBlockPerIter}); + b_warp_tensors(I0) = load_tile(b_warp_windows(I0)(kIter)); + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + if constexpr(nIter < NIterPerWarp - 1) + { + // read B warp tensor from B Block window + b_warp_windows(number{})(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(number{})(kIter), + {(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter}); + b_warp_tensors(number{}) = + load_tile(b_warp_windows(number{})(kIter)); + }; + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + 
+ using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp new file mode 100644 index 0000000000..c731539134 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemTrLoadCRegV2PrefetchN +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<1>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // construct from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + constexpr auto b_warp_dstr_encode = + typename InputTileDistributionTraits::TransposedDstrEncode{}; + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{0, iNWarp * WG::kN}, + make_static_tile_distribution(b_warp_dstr_encode)); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = 
uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + constexpr auto I0 = number<0>{}; + + using b_warp_tensor_type = decltype(load_tile_transpose(b_warp_windows(I0)(I0))); + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + statically_indexed_array b_warp_tensors; + + // read B warp tensor from B Block window + b_warp_windows(I0)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(I0)(kIter), + {kIter * KPerBlockPerIter, 0 * NPerBlockPerIter}); + b_warp_tensors(I0) = load_tile_transpose(b_warp_windows(I0)(kIter)); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + if constexpr(nIter < NIterPerWarp - 1) + { + // read B warp tensor from B Block window + b_warp_windows(number{})(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(number{})(kIter), + {kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter}); + b_warp_tensors(number{}) = + load_tile_transpose(b_warp_windows(number{})(kIter)); + }; + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode 
= detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp new file mode 100644 index 0000000000..e2b0cb74ba --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using BF8 = ck::bf8_t; +using F8 = ck::f8_t; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_data_xdl_v3_f16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| 
PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true> + + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_xdl_v3_bf16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + 
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true> + + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index f784b6ea51..09301474f0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -108,6 +108,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { + add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances( op_ptrs); @@ -148,6 +150,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { + add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc index 7c61f3ee66..8dae166dd1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -56,6 +56,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( #endif #ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 19e27cf173..7f2363affd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -32,6 +32,8 @@ add_instance_library( xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp 
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..4d434cc390 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..9d1fb4b93a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_f16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. 
Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_f16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck
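
A brief usage sketch, not part of this patch: once the two new instance translation units above are linked in, the V3 (CShuffleV3) f16/bf16 instances become visible through the existing DeviceOperationInstanceFactory specialization that this diff extends. The snippet below only illustrates how a client might enumerate them; the exact DeviceGroupedConvBwdDataMultipleD template-parameter order, the chosen f16 data types, and the 2-D NHWGK/GKYXC/NHWGC layout combination are assumptions mirroring the instances registered above, and linking against the grouped_conv2d_bwd_data instance library is implied but not shown.

// sketch.cpp -- illustrative only; parameter choices are assumptions, see note above
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp"

int main()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_layout::convolution;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Grouped 2-D conv bwd-data: A = output gradient (NHWGK), B = weights (GKYXC),
    // E = input gradient (NHWGC), no fused D tensors, f16 in/out (assumed ordering).
    using DeviceOp = DeviceGroupedConvBwdDataMultipleD<2,
                                                       NHWGK,
                                                       GKYXC,
                                                       ck::Tuple<>,
                                                       NHWGC,
                                                       ck::half_t,
                                                       ck::half_t,
                                                       ck::Tuple<>,
                                                       ck::half_t,
                                                       PassThrough,
                                                       PassThrough,
                                                       PassThrough>;

    // After this patch, the returned list should also contain the
    // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 instances added above.
    const auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "grouped conv2d bwd-data instances: " << op_ptrs.size() << '\n';
    for(const auto& op : op_ptrs)
        std::cout << "  " << op->GetTypeString() << '\n';
    return 0;
}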