From a29eac5da06cce8b5994f9b98dd4eef9155dbfcd Mon Sep 17 00:00:00 2001 From: Yi DING Date: Mon, 25 Aug 2025 20:55:12 +0800 Subject: [PATCH] [CK_TILE] FMHA avoid unnecessary vmcnt0 (#2715) * FMHA avoid unnecessary vmcnt0 Squashed commit of the following: commit 61f5a8d4ef2cb74c0bd4caac359708d6fdb50de7 Author: aska-0096 Date: Fri Aug 22 03:15:51 2025 +0000 merge develop and solve conflicts commit ed7d18e306e16e6f39170a8ae4202d5df7b4045c Merge: 2dac61a4f 13a6816fb Author: aska-0096 Date: Fri Aug 22 03:15:21 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into vmcnt0issue commit 2dac61a4f8d28fde9c466ae3ce56435fb679a140 Author: Ding, Yi Date: Tue Aug 19 02:17:43 2025 +0000 update bwd commit 281bfa9cc94eb08effdcdb6e8028bccc1d166682 Author: Kevin Choi Date: Mon Aug 18 19:36:38 2025 +0000 add restrict to applicable functions commit 45534dee5bcbe532da46fc5cd6601cde10d84387 Author: Ding, Yi Date: Mon Aug 18 02:07:03 2025 +0000 bwd filter commit 7abd7b372b82cba94a457238b6b4a81d093e7280 Author: Kevin Choi Date: Sat Aug 16 08:15:23 2025 +0000 remove noinline attr as it causes a lot more s_waitcnt's commit 89c29746a09255c1d26038171157e91d1b68d14a Author: Kevin Choi Date: Thu Aug 14 12:11:17 2025 +0000 remove innerloop, move restrict parameters to mainloop and add noinline attribute. 
commit 6f61b3a5c80011411aa3aebf7983602f7c117566 Author: Kevin Choi Date: Thu Aug 14 07:06:51 2025 +0000 Create inner lambda with restrict parameters, add restrict to some parameters commit 4e17551191980ea7a7e71e9798946cf1dc9f1a1a Author: aska-0096 Date: Thu Aug 14 03:43:54 2025 +0000 save for debug commit 5f2c3cfa86c6951208a1cc227fa556704a885a88 Merge: 25f067b4f 165a2723c Author: aska-0096 Date: Wed Aug 13 02:15:22 2025 +0000 Merge branch 'wip-async-tr-fa' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit 25f067b4f09d6909a05e252c7621124046dfda57 Merge: 447c1c5d6 2ad2f97b7 Author: aska-0096 Date: Wed Aug 13 02:14:26 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit 165a2723c557420b48891cc1ce3434e3675aef5d Merge: 447c1c5d6 4491739ab Author: asleepzzz Date: Wed Aug 13 00:34:11 2025 +0800 Merge branch 'develop' into wip-async-tr-fa commit 447c1c5d6ef0474f9a54c06eea68d65b0346f9b6 Author: aska-0096 Date: Tue Aug 12 14:25:50 2025 +0000 refactor blockgemm change, isolate to v2; commit 8f67083511ff77d31c880f4427d3bdf53a179568 Author: aska-0096 Date: Tue Aug 12 09:26:13 2025 +0000 clang format commit 3f28caa88b9ac9d84029948a7bacf1175cc5a965 Merge: c84662c34 245071bcf Author: aska-0096 Date: Tue Aug 12 09:04:41 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit c84662c345755ec5f3d524fdde4aa951c8f86298 Author: aska-0096 Date: Tue Aug 12 08:46:06 2025 +0000 Fix the bug commit e0647ffa5646f8132529b152af02750c4010013d Author: aska-0096 Date: Tue Aug 12 04:02:41 2025 +0000 fix conflict. 
disable all v-col instance for fmha fwd commit 781f98236c376f57591a6d481cc2ee04b36a148b Merge: 241f3d7dc 6e03d9607 Author: aska-0096 Date: Tue Aug 12 03:52:34 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit 241f3d7dc35b2d1cca4eca8ba714581e84f5725e Author: aska-0096 Date: Tue Aug 12 01:53:31 2025 +0000 clang format commit 8ee83f1c492ae9600a947c4cfe5f7cd25156138f Merge: 1a629c098 3639befe9 Author: aska-0096 Date: Tue Aug 12 01:52:52 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit 1a629c09876cc05f0750db7eade1d527dc32a1d3 Merge: f65874e5b b34a029cd Author: aska-0096 Date: Mon Aug 11 15:59:40 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit f65874e5b07579d5b734b4c68877679a3ee04dac Author: aska-0096 Date: Mon Aug 11 15:37:37 2025 +0000 change the warp setting for hdim32 fmha fwd commit 7c5f5e65e97486c074ef9a138900ed9aafea547e Author: aska-0096 Date: Mon Aug 11 14:21:09 2025 +0000 tempsave, update the blocksync functions commit beb0950ad8c6b0366a77f5b82e7d5c5f8663b915 Author: aska-0096 Date: Sun Aug 10 06:00:51 2025 +0000 fix bug in pki4 commit 073db2e18af21f1ed1fb3d1f1c15830838df986f Author: aska-0096 Date: Sat Aug 9 03:25:12 2025 +0000 fix bugs in gemm commit 01f2d7bd763f64f19861b8a2a861b50bd0aed70a Author: aska-0096 Date: Fri Aug 8 18:35:53 2025 +0000 fix bug on non-gfx950 commit 9a9ca06d59cb1721b4fa70a0d3253fb6b252b37e Author: aska-0096 Date: Fri Aug 8 17:53:19 2025 +0000 fix bug commit 30de97f473685e0bd5b82f15eee2493d9a05cffd Author: aska-0096 Date: Fri Aug 8 15:42:15 2025 +0000 fix bugs commit f449cb85a3cfb27bf86525e9c11a2ecf4f7a73a7 Author: aska-0096 Date: Fri Aug 8 09:31:01 2025 +0000 fix clangformat with 18.1.3 commit e4cb185c41586d018771a5413efd909d8d53a8c5 Author: aska-0096 Date: Fri Aug 8 09:07:40 2025 +0000 remove non-necessary change commit 498f0d44cfba17287cce8d10855cce5c5de263db 
Author: aska-0096 Date: Fri Aug 8 09:04:02 2025 +0000 bug fix, clang format; commit 3cb648cbc4883e6889340d85f48d803a21b9c805 Author: aska-0096 Date: Fri Aug 8 08:08:03 2025 +0000 Remove unnecessary changes commit 9e7ff3b611b7933b65973907a0cae312a15d31c6 Merge: a3c1bfe6d 7f14bd1df Author: aska-0096 Date: Fri Aug 8 07:50:12 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit a3c1bfe6dd64572e4371c7b1b8b5a809aad90c71 Author: aska-0096 Date: Fri Aug 8 06:19:31 2025 +0000 remove unnecessary files; rename some files commit 6c257fa27729c005d539b5b71deeba3703031089 Author: aska-0096 Date: Fri Aug 8 05:46:18 2025 +0000 merge fa_decode pipeline into fmha_fwd api commit 26c911b4e5e43aa78fadc5b7c7880421b94d9449 Author: aska-0096 Date: Wed Aug 6 05:58:43 2025 +0000 add __restrict__ to tr load commit bbad2b979b701533b74f43452ffe0f775e019139 Author: aska-0096 Date: Tue Aug 5 07:23:51 2025 +0000 Add XOR fold strategy for hdim<128, but perf dropped; disable it by default; wait further perf debug commit d7fabd5f765e2a573ddbaf0857ce6f691407e562 Author: aska-0096 Date: Mon Aug 4 10:27:42 2025 +0000 Add v_permlaneb32 for block_reduce. 
Disable it as it will cause un-coexecutable packed math in FA commit 9f2c1c5baddaa3a2aa9cd70c4a62401df3c29fd9 Author: aska-0096 Date: Mon Aug 4 10:02:17 2025 +0000 add vmcnt guard before load ktile commit f9772f8b6035bc92aa08fb4d092fc21b6b24445c Author: aska-0096 Date: Mon Aug 4 06:49:01 2025 +0000 Load Q through lds, implement xor; commit 62bb9f05177dfb8280d6c2be67a88492d6be4838 Author: aska-0096 Date: Fri Aug 1 10:44:54 2025 +0000 small refactor commit 7cb83c2ab6a87d161259eeb8d5ac3e27ce9587af Author: aska-0096 Date: Thu Jul 31 10:25:37 2025 +0000 upgrade prefill pipeline; simple iglp; consistent data produce and consume order commit 3a85dee389c424490a5101f05c3f4aa3a1ea70be Author: aska-0096 Date: Thu Jul 31 05:13:27 2025 +0000 enable larger tile size; upgrade xor pattern commit a468e59a01d6dd85c105ca30ac249491256c5915 Author: aska-0096 Date: Wed Jul 30 12:25:33 2025 +0000 remove all lds bankconflict with xor layouts commit 39ff55cdc377311112100fb24bc013adfd8960c0 Author: aska-0096 Date: Wed Jul 30 03:51:06 2025 +0000 enable prefill overload operator(). commit a7b152a788e8035c93f8e4cbf317863182665d8f Author: aska-0096 Date: Fri Jul 25 07:10:01 2025 +0000 fix the lds alignment caused performance regression commit c4e99bc8f502cd019a754cc9e0043e3d8b9d0f3e Author: aska-0096 Date: Wed Jul 23 09:05:57 2025 +0000 remove unnecessary features commit 9758750801c7fd5a80f654eb982f43b87d674fa3 Author: aska-0096 Date: Tue Jul 22 08:04:05 2025 +0000 tempsave. 
asynccopy+trload sanity checked commit 1c4c04d725047357224ebf8a2b94d9010a5651a6 Author: aska-0096 Date: Mon Jul 21 05:55:55 2025 +0000 tempsave, trload+asyncload done commit 75e68f91fc5a1f35cd5d96901efe15c346a1bd5c Author: aska-0096 Date: Fri Jul 18 10:04:34 2025 +0000 compile pass commit d41b5eace939909084d32281710fb81142ad5fec Merge: 3f86a81ee 33204e15f Author: aska-0096 Date: Fri Jul 18 05:17:27 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa commit 3f86a81eee75256a78df02032d50814aaa42b038 Author: aska-0096 Date: Fri Jul 18 05:16:39 2025 +0000 tempsave commit 7d43f7446a9a20773f70e08462393f6c9afb7280 Author: aska-0096 Date: Thu Jul 17 10:06:09 2025 +0000 temp save commit 727629cd9115f1be9c1800bb65a8ea84ff06c250 Merge: aa5da19c9 94bceebc9 Author: aska-0096 Date: Thu Jul 17 07:24:32 2025 +0000 Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into fa_decode_pipeline commit 94bceebc96ef4885e0ac861b7793e2e2897481bd Author: aska-0096 Date: Thu Jul 17 03:10:46 2025 +0000 move test_copy into test commit 8f8bfe7f33884f1588bb7aa1a1d599521f40a30e Author: aska-0096 Date: Thu Jul 17 02:41:31 2025 +0000 remove unnecessary output commit b1dbcacb1832560c6cc967a079dffce558228f0b Merge: 5b0d311e6 0eaf3325a Author: aska-0096 Date: Thu Jul 17 02:26:13 2025 +0000 Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into test_copy_fix commit 5b0d311e649257557a7014c28fcfac0c327b77b5 Author: aska-0096 Date: Thu Jul 17 02:26:10 2025 +0000 add input validation and bug fix commit 0eaf3325a8e019402ff12a2402f446f8471f584f Merge: a66e1d29a 08c5df68a Author: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Wed Jul 16 11:23:57 2025 -0700 Merge branch 'develop' into test_copy_fix commit a66e1d29a8cccc17cc8958d970ec7b1281ec8291 Author: aska-0096 Date: Wed Jul 16 08:55:50 2025 +0000 fix vmcnt shift commit 197bdcb4827dae6d8460ed375e6265c2c9ddaef0 Author: aska-0096 Date: Wed Jul 16 
08:37:07 2025 +0000 Improve s_waitcnt_imm calculation commit 3b59e26cf8e0ba573a99a6caa0f37296b23b8bd2 Author: aska-0096 Date: Wed Jul 16 05:39:50 2025 +0000 fix the s_waitcnt_imm calculation commit 1c0870089a0e7c78ed71a278bf52d98fc780e482 Merge: d6ee05e36 d9de58c66 Author: aska-0096 Date: Wed Jul 16 03:57:57 2025 +0000 Merge branch 'develop' of https://github.com/ROCm/composable_kernel into test_copy_fix commit d6ee05e360dc8426ed2a08a8d6877ebf5cabbd32 Author: aska-0096 Date: Wed Jul 16 03:54:33 2025 +0000 Add block_sync_lds_direct_load utility commit c037a72040217471f52ee76bed9c07bf5b22aef4 Author: aska-0096 Date: Tue Jul 15 09:39:03 2025 +0000 fix async copytest bug commit aa5da19c94022449b027e7a57668f2e219f0f171 Author: aska-0096 Date: Thu Jul 10 04:29:33 2025 +0000 temp save, change all instance to 1wave commit ddd172feb9eb2cb783420a8db6f44d51b350c370 Author: aska-0096 Date: Tue Jul 8 08:37:20 2025 +0000 tempsave, fmha_decode commit fd90531f4eafdfdbf7df0f3731018fc57dcf4a33 Author: aska-0096 Date: Sat Jun 21 15:02:57 2025 +0000 temp save, waiting for debug commit 71dd31f15bca01995c8cb0be9e903103f4657181 Author: aska-0096 Date: Thu Jun 19 05:11:52 2025 +0000 save an example for __bf16 type commit cdf33e079fa7d7d5b03b06550df2356b02041d7b Author: aska-0096 Date: Wed Jun 18 07:27:24 2025 +0000 fix bwd code commit d630998dc6751f44097b1e9a239bb5063a793736 Author: aska-0096 Date: Wed Jun 18 06:37:16 2025 +0000 Fix for fwd/bwd kernel build filter commit d5ec3d0e5768aafed7f77151b2a835e87b9f95ba Author: Ding, Yi Date: Tue Aug 19 08:13:18 2025 +0000 Add restrict to avoid unnecessary vmcnt --------- Co-authored-by: aska-0096 * Add comments for C-style cast * Better comments --------- Co-authored-by: aska-0096 [ROCm/composable_kernel commit: de61e554938265a5d17a1bba8c148457125e80cd] --- .../core/arch/amd_buffer_addressing.hpp | 39 ++--- .../arch/amd_buffer_addressing_builtins.hpp | 39 ++--- include/ck_tile/core/tensor/buffer_view.hpp | 20 +-- 
include/ck_tile/core/tensor/tensor_view.hpp | 6 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 8 +- ...k_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 4 +- ...a_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 4 +- ...bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 137 +++++++++++------- ...wd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 70 +++++---- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 41 +++--- 10 files changed, 217 insertions(+), 151 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 037e86909d..7a9c017eb2 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1833,14 +1833,17 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, if constexpr(oob_conditional_check) v_offset = flag ? v_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - src_wave_addr_offset, - /*src_immediate_addr_offset*/ 0, - static_cast(coherence)); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + // Use C-style cast to change address space without dropping llvm noalias attribute + llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, + (as3_uint32_ptr)(smem), + bytes, + v_offset, + src_wave_addr_offset, + /*src_immediate_addr_offset*/ 0, + static_cast(coherence)); +#pragma clang diagnostic pop } template & src_thread_ template __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) { +#define __LDS_ADDR __attribute__((address_space(3))) static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), "We need to have the compatible compiler version to build this instruction"); + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + // Use C-style cast to change address space without dropping llvm noalias attribute + const 
auto in_ptr_ = (__LDS_ADDR T*)(const_cast(in_ptr)); +#pragma clang diagnostic pop if constexpr(std::is_same_v, ck_tile::half_t>) { typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t; - __attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>( - reinterpret_cast(in_ptr)); + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_fp16x4_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr)); } else if constexpr(std::is_same_v, ck_tile::bf16_t>) { typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t; - __attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>( - reinterpret_cast(in_ptr)); + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_bf16x4_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr)); } else if constexpr(std::is_same_v, ck_tile::fp8_t> || @@ -2812,15 +2818,14 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) std::is_same_v, ck_tile::int8_t>) { typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t; - __attribute__((address_space(3))) llvm_i32x2_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>( - reinterpret_cast(in_ptr)); + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); } else { static_assert(false, "not implemented"); } +#undef __LDS_ADDR } #endif diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index d1e4eb3da3..4013b51479 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1603,14 +1603,17 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, 
if constexpr(oob_conditional_check) v_offset = flag ? v_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - src_wave_addr_offset, - /*src_immediate_addr_offset*/ 0, - static_cast(coherence)); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + // Use C-style cast to change address space without dropping llvm noalias attribute + llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, + (as3_uint32_ptr)(smem), + bytes, + v_offset, + src_wave_addr_offset, + /*src_immediate_addr_offset*/ 0, + static_cast(coherence)); +#pragma clang diagnostic pop } template __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) { +#define __LDS_ADDR __attribute__((address_space(3))) static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), "We need to have the compatible compiler version to build this instruction"); + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + // Use C-style cast to change address space without dropping llvm noalias attribute + const auto in_ptr_ = (__LDS_ADDR T*)(const_cast(in_ptr)); +#pragma clang diagnostic pop if constexpr(std::is_same_v, ck_tile::half_t>) { typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t; - __attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>( - reinterpret_cast(in_ptr)); + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_fp16x4_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr)); } else if constexpr(std::is_same_v, ck_tile::bf16_t>) { typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t; - __attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>( - reinterpret_cast(in_ptr)); + auto lds_ptr = 
reinterpret_cast<__LDS_ADDR llvm_bf16x4_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr)); } else if constexpr(std::is_same_v, ck_tile::fp8_t> || @@ -2630,15 +2636,14 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) std::is_same_v, ck_tile::int8_t>) { typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t; - __attribute__((address_space(3))) llvm_i32x2_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>( - reinterpret_cast(in_ptr)); + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); } else { static_assert(false, "not implemented"); } +#undef __LDS_ADDR } #endif diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index ca314a6abe..d1e770ef42 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -62,12 +62,12 @@ struct buffer_view -CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_size) +CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size) { return buffer_view{p, buffer_size}; } @@ -1266,7 +1266,7 @@ template , remove_cvref_t>::value, bool>::type = false> CK_TILE_HOST_DEVICE constexpr auto -make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value) +make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size, X invalid_element_value) { return buffer_view{ p, buffer_size, invalid_element_value}; diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index a85dbc6d00..6fa8f898e5 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -449,7 +449,7 @@ template -CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, +CK_TILE_HOST_DEVICE constexpr auto 
make_tensor_view(DataType* __restrict__ p, const tensor_descriptor& desc) { auto buffer_view = @@ -468,7 +468,7 @@ template ::type = false> CK_TILE_HOST_DEVICE constexpr auto -make_naive_tensor_view(DataType* p, +make_naive_tensor_view(DataType* __restrict__ p, const tuple& lengths, const tuple& strides, number = number<-1>{}, @@ -491,7 +491,7 @@ template CK_TILE_HOST_DEVICE constexpr auto -make_naive_tensor_view_packed(DataType* p, +make_naive_tensor_view_packed(DataType* __restrict__ p, const tuple& lengths, number = number<-1>{}) { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 5e16fc563b..3f5bef366e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -1115,7 +1115,8 @@ struct FmhaBwdDQDKDVKernel {i_n0, 0}); if constexpr(!kUseQrQtrDorPipeline) { - auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window, + auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(smem_ptr, + q_dram_window, k_dram_window, v_dram_window, bias_dram_window, @@ -1131,7 +1132,6 @@ struct FmhaBwdDQDKDVKernel kargs.scale, rp_undrop, scale_rp_undrop, - smem_ptr, dropout); KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile); @@ -1139,7 +1139,8 @@ struct FmhaBwdDQDKDVKernel } else { - FmhaPipeline{}(q_dram_window, + FmhaPipeline{}(smem_ptr, + q_dram_window, k_dram_window, v_dram_window, bias_dram_window, @@ -1160,7 +1161,6 @@ struct FmhaBwdDQDKDVKernel kargs.scale, rp_undrop, scale_rp_undrop, - smem_ptr, dropout); } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index d36f8ad724..5e63fb714a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -93,7 +93,8 @@ struct 
BlockFmhaBwdDQDKDVPipelineKRKTRVR typename BiasGradDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + operator()(void* smem_ptr, + const QDramBlockWindowTmp& q_dram_block_window_tmp, const KDramBlockWindowTmp& k_dram_block_window_tmp, const VDramBlockWindowTmp& v_dram_block_window_tmp, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, @@ -109,7 +110,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR float scale, float rp_undrop, float scale_rp_undrop, - void* smem_ptr, FmhaDropout& dropout) const { static_assert( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 88fb1281aa..b883aad155 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -93,7 +93,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP typename BiasGradDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + operator()(void* smem_ptr, + const QDramBlockWindowTmp& q_dram_block_window_tmp, const KDramBlockWindowTmp& k_dram_block_window_tmp, const VDramBlockWindowTmp& v_dram_block_window_tmp, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, @@ -109,7 +110,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP float scale, float rp_undrop, float scale_rp_undrop, - void* smem_ptr, FmhaDropout& dropout) const { static_assert( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 9a31498dd1..9bd78b4077 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -90,6 +90,53 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR else return raw_lse; }; + template + CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const + { + // LDS allocation + // cast to char* to do pointer arithmetic + const auto smem_ptr_ = reinterpret_cast(smem_ptr); + const auto k_lds_ptr = reinterpret_cast(smem_ptr_); + const auto v_lds_ptr = + reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeK()); + + const auto do_lds_ptr0 = reinterpret_cast(smem_ptr_); + const auto do_lds_ptr1 = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad()); + const auto q_lds_ptr0 = reinterpret_cast( // + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad()); + const auto q_lds_ptr1 = reinterpret_cast( // + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ()); + const auto lse_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ()); + const auto d_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeLSE()); + const auto ds_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeLSE() + Policy::template GetSmemSizeD()); + const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); + return run(k_lds_ptr, + v_lds_ptr, + do_lds_ptr0, + do_lds_ptr1, + q_lds_ptr0, + q_lds_ptr1, + lse_lds_ptr, + d_lds_ptr, + ds_lds_ptr, + bias_lds_ptr, + std::forward(args)...); + } template - CK_TILE_DEVICE auto 
operator()( // + CK_TILE_DEVICE auto run( // + KDataType* __restrict__ k_lds_ptr, + VDataType* __restrict__ v_lds_ptr, + OGradDataType* __restrict__ do_lds_ptr0, + OGradDataType* __restrict__ do_lds_ptr1, + QDataType* __restrict__ q_lds_ptr0, + QDataType* __restrict__ q_lds_ptr1, + LSEDataType* __restrict__ lse_lds_ptr, + DDataType* __restrict__ d_lds_ptr, + GemmDataType* __restrict__ ds_lds_ptr, + BiasDataType* __restrict__ bias_lds_ptr, const QDramBlockWindowTmp& q_dram_block_window_tmp, const KDramBlockWindowTmp& k_dram_block_window_tmp, const VDramBlockWindowTmp& v_dram_block_window_tmp, @@ -119,7 +176,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR float scale, float rp_undrop, float scale_rp_undrop, - void* smem_ptr, FmhaDropout& dropout) const { static_assert( @@ -184,40 +240,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR } } - // LDS allocation - const auto smem_ptr_ = - reinterpret_cast(smem_ptr); // cast to char* to do pointer arithmetic - - const auto k_lds_ptr = reinterpret_cast(smem_ptr_); - const auto v_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeK()); - - const auto do_lds_ptr0 = reinterpret_cast(smem_ptr_); - const auto do_lds_ptr1 = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad()); - const auto q_lds_ptr0 = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad()); - const auto q_lds_ptr1 = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ()); - const auto lse_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ()); - const auto d_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + - 
Policy::template GetSmemSizeLSE()); - const auto ds_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + - Policy::template GetSmemSizeLSE() + Policy::template GetSmemSizeD()); - const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); - auto k_lds = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); auto k_lds_write_window = @@ -453,13 +475,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR decltype(load_tile(d_dram_window)) d_block_tile; index_t i_total_bodys = 0; - auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable { - const bool is_even = (i_total_bodys % 2 == 0); - QDataType* const __restrict__ q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0; - QDataType* const __restrict__ q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1; - OGradDataType* const __restrict__ do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0; - OGradDataType* const __restrict__ do_lds_ptr_next = is_even ? 
do_lds_ptr0 : do_lds_ptr1; - + auto main_body_impl = [&](auto is_prologue_, + auto is_epilogue_, + QDataType* const __restrict__ q_lds_ptr_curr, + QDataType* const __restrict__ q_lds_ptr_next, + OGradDataType* const __restrict__ do_lds_ptr_curr, + OGradDataType* const __restrict__ do_lds_ptr_next) mutable { constexpr bool is_prologue = is_prologue_.value; constexpr bool is_epilogue = is_epilogue_.value; static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true"); @@ -467,19 +488,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR if constexpr(is_prologue) { + lse_block_tile = load_tile(lse_dram_window); + move_tile_window(lse_dram_window, {kM0}); + + d_block_tile = load_tile(d_dram_window); + move_tile_window(d_dram_window, {kM0}); + q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next); async_load_tile(q_lds_write_window, q_dram_window); move_tile_window(q_dram_window, {kM0, 0}); - lse_block_tile = load_tile(lse_dram_window); - move_tile_window(lse_dram_window, {kM0}); - do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next); async_load_tile(do_lds_write_window, do_dram_window); move_tile_window(do_dram_window, {kM0, 0}); - - d_block_tile = load_tile(d_dram_window); - move_tile_window(d_dram_window, {kM0}); } if constexpr(is_epilogue) { @@ -611,8 +632,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR constexpr auto i_j_idx = make_tuple(idx0, idx1); bool undrop_flag = p[i_j_idx] >= 0; ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag - ? (dp_acc[i_j_idx] - d[i_idx]) - : d[i_idx]); + ? (dp_acc[i_j_idx] - d[i_idx]) + : d[i_idx]); }); }); @@ -725,6 +746,20 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR } move_tile_window(dq_dram_window, {kM0, 0}); } + }; + + auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable { + const bool is_even = (i_total_bodys % 2 == 0); + const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0; + const auto q_lds_ptr_next = is_even ? 
q_lds_ptr0 : q_lds_ptr1; + const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0; + const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1; + main_body_impl(is_prologue_, + is_epilogue_, + q_lds_ptr_curr, + q_lds_ptr_next, + do_lds_ptr_curr, + do_lds_ptr_next); i_total_bodys += 1; }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 789cfb3ea4..5adb64564d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -93,6 +93,42 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR return raw_lse; }; + template + CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const + { + // LDS allocation + const auto smem_ptr_ = + reinterpret_cast(smem_ptr); // cast to char* to do pointer arithmetic + + const auto k_lds_ptr = reinterpret_cast(smem_ptr_); + const auto v_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeK()); + + const auto do_lds_ptr = reinterpret_cast(smem_ptr_); + const auto q_lds_ptr = reinterpret_cast( // + smem_ptr_ + Policy::template GetSmemSizeOGrad()); + const auto lse_lds_ptr = reinterpret_cast( // + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ()); + const auto d_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE()); + + const auto ds_lds_ptr = + reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeV()); + const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); + return run(k_lds_ptr, + v_lds_ptr, + do_lds_ptr, + q_lds_ptr, + lse_lds_ptr, + d_lds_ptr, + ds_lds_ptr, + bias_lds_ptr, + std::forward(args)...); + } + template - CK_TILE_DEVICE auto 
operator()( // + CK_TILE_DEVICE auto run( // + KDataType* __restrict__ k_lds_ptr, + VDataType* __restrict__ v_lds_ptr, + OGradDataType* __restrict__ do_lds_ptr, + QDataType* __restrict__ q_lds_ptr, + LSEDataType* __restrict__ lse_lds_ptr, + DDataType* __restrict__ d_lds_ptr, + GemmDataType* __restrict__ ds_lds_ptr, + BiasDataType* __restrict__ bias_lds_ptr, const QDramBlockWindowTmp& q_dram_block_window_tmp, const KDramBlockWindowTmp& k_dram_block_window_tmp, const VDramBlockWindowTmp& v_dram_block_window_tmp, @@ -131,7 +175,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR float scale, float rp_undrop, float scale_rp_undrop, - void* smem_ptr, FmhaDropout& dropout) const { static_assert( @@ -181,29 +224,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR {seqlen_kv_start, 0}, Policy::template MakeKDramTileDistribution()); - // LDS allocation - const auto smem_ptr_ = - reinterpret_cast(smem_ptr); // cast to char* to do pointer arithmetic - - const auto k_lds_ptr = reinterpret_cast(smem_ptr_); - const auto v_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeK()); - - const auto do_lds_ptr = reinterpret_cast(smem_ptr_); - const auto q_lds_ptr = reinterpret_cast( // - smem_ptr_ + Policy::template GetSmemSizeOGrad()); - const auto lse_lds_ptr = reinterpret_cast( // - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ()); - const auto d_lds_ptr = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE()); - - const auto ds_lds_ptr = - reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeV()); - const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); - auto k_lds = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); auto k_lds_write_window = diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 39d8814692..aafe481d2b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -638,11 +638,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload typename LSEaccDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& __restrict__ bias_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& __restrict__ lse_acc_dram_window_tmp, // M0*1 tile FmhaMask mask, PositionEncoding position_encoding, float scale_s, @@ -854,18 +854,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload __builtin_amdgcn_sched_barrier(0); - auto mainloop = [&](index_t cur_loop) { - const bool is_even_loop = (cur_loop % 2 == 0); - - auto k_lds_write_ptr = is_even_loop ? static_cast(smem_ptrk0) - : static_cast(smem_ptrk1); - auto k_lds_read_ptr = is_even_loop ? static_cast(smem_ptrk1) - : static_cast(smem_ptrk0); - auto v_lds_write_ptr = is_even_loop ? 
static_cast(smem_ptrv1) - : static_cast(smem_ptrv0); - auto v_lds_read_ptr = is_even_loop ? static_cast(smem_ptrv0) - : static_cast(smem_ptrv1); - + auto mainloop = [&](KDataType* __restrict__ k_lds_write_ptr, + KDataType* __restrict__ k_lds_read_ptr, + KDataType* __restrict__ v_lds_write_ptr, + KDataType* __restrict__ v_lds_read_ptr) { // move V tile windows block_sync_lds(); move_tile_window(v_dram_window, {kN0, 0}); @@ -1110,11 +1102,20 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ }); - }; + }; // mainloop do { - mainloop(i_total_loops); + bool is_even_loop = i_total_loops % 2 == 0; + auto k_lds_write_ptr = is_even_loop ? static_cast(smem_ptrk0) + : static_cast(smem_ptrk1); + auto k_lds_read_ptr = is_even_loop ? static_cast(smem_ptrk1) + : static_cast(smem_ptrk0); + auto v_lds_write_ptr = is_even_loop ? static_cast(smem_ptrv1) + : static_cast(smem_ptrv0); + auto v_lds_read_ptr = is_even_loop ? static_cast(smem_ptrv0) + : static_cast(smem_ptrv1); + mainloop(k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr); i_total_loops++; } while(i_total_loops < num_total_loop);