mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK_TILE] FMHA avoid unnecessary vmcnt0 (#2715)
* FMHA avoid unnecessary vmcnt0
Squashed commit of the following:
commit 61f5a8d4ef2cb74c0bd4caac359708d6fdb50de7
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 22 03:15:51 2025 +0000
merge develop and solve conflicts
commit ed7d18e306e16e6f39170a8ae4202d5df7b4045c
Merge: 2dac61a4f 13a6816fb
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 22 03:15:21 2025 +0000
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into vmcnt0issue
commit 2dac61a4f8d28fde9c466ae3ce56435fb679a140
Author: Ding, Yi <yi.ding@amd.com>
Date: Tue Aug 19 02:17:43 2025 +0000
update bwd
commit 281bfa9cc94eb08effdcdb6e8028bccc1d166682
Author: Kevin Choi <kevin.choi@amd.com>
Date: Mon Aug 18 19:36:38 2025 +0000
add restrict to applicable functions
commit 45534dee5bcbe532da46fc5cd6601cde10d84387
Author: Ding, Yi <yi.ding@amd.com>
Date: Mon Aug 18 02:07:03 2025 +0000
bwd filter
commit 7abd7b372b82cba94a457238b6b4a81d093e7280
Author: Kevin Choi <kevin.choi@amd.com>
Date: Sat Aug 16 08:15:23 2025 +0000
remove noinline attr as it causes a lot more s_waitcnt's
commit 89c29746a09255c1d26038171157e91d1b68d14a
Author: Kevin Choi <kevin.choi@amd.com>
Date: Thu Aug 14 12:11:17 2025 +0000
remove innerloop, move restrict parameters to mainloop and add noinline attribute.
commit 6f61b3a5c80011411aa3aebf7983602f7c117566
Author: Kevin Choi <kevin.choi@amd.com>
Date: Thu Aug 14 07:06:51 2025 +0000
Create inner lambda with restrict parameters, add restrict to some parameters
commit 4e17551191980ea7a7e71e9798946cf1dc9f1a1a
Author: aska-0096 <haocwang@amd.com>
Date: Thu Aug 14 03:43:54 2025 +0000
save for debug
commit 5f2c3cfa86c6951208a1cc227fa556704a885a88
Merge: 25f067b4f 165a2723c
Author: aska-0096 <haocwang@amd.com>
Date: Wed Aug 13 02:15:22 2025 +0000
Merge branch 'wip-async-tr-fa' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa
commit 25f067b4f09d6909a05e252c7621124046dfda57
Merge: 447c1c5d6 2ad2f97b7
Author: aska-0096 <haocwang@amd.com>
Date: Wed Aug 13 02:14:26 2025 +0000
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa
commit 165a2723c557420b48891cc1ce3434e3675aef5d
Merge: 447c1c5d6 4491739ab
Author: asleepzzz <hanwen.chang@amd.com>
Date: Wed Aug 13 00:34:11 2025 +0800
Merge branch 'develop' into wip-async-tr-fa
commit 447c1c5d6ef0474f9a54c06eea68d65b0346f9b6
Author: aska-0096 <haocwang@amd.com>
Date: Tue Aug 12 14:25:50 2025 +0000
refactor blockgemm change, isolate to v2;
commit 8f67083511ff77d31c880f4427d3bdf53a179568
Author: aska-0096 <haocwang@amd.com>
Date: Tue Aug 12 09:26:13 2025 +0000
clang format
commit 3f28caa88b9ac9d84029948a7bacf1175cc5a965
Merge: c84662c34 245071bcf
Author: aska-0096 <haocwang@amd.com>
Date: Tue Aug 12 09:04:41 2025 +0000
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa
commit c84662c345755ec5f3d524fdde4aa951c8f86298
Author: aska-0096 <haocwang@amd.com>
Date: Tue Aug 12 08:46:06 2025 +0000
Fix the bug
commit e0647ffa5646f8132529b152af02750c4010013d
Author: aska-0096 <haocwang@amd.com>
Date: Tue Aug 12 04:02:41 2025 +0000
fix conflict. disable all v-col instance for fmha fwd
commit 781f98236c376f57591a6d481cc2ee04b36a148b
Merge: 241f3d7dc 6e03d9607
Author: aska-0096 <haocwang@amd.com>
Date: Tue Aug 12 03:52:34 2025 +0000
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa
commit 241f3d7dc35b2d1cca4eca8ba714581e84f5725e
Author: aska-0096 <haocwang@amd.com>
Date: Tue Aug 12 01:53:31 2025 +0000
clang format
commit 8ee83f1c492ae9600a947c4cfe5f7cd25156138f
Merge: 1a629c098 3639befe9
Author: aska-0096 <haocwang@amd.com>
Date: Tue Aug 12 01:52:52 2025 +0000
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa
commit 1a629c09876cc05f0750db7eade1d527dc32a1d3
Merge: f65874e5b b34a029cd
Author: aska-0096 <haocwang@amd.com>
Date: Mon Aug 11 15:59:40 2025 +0000
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa
commit f65874e5b07579d5b734b4c68877679a3ee04dac
Author: aska-0096 <haocwang@amd.com>
Date: Mon Aug 11 15:37:37 2025 +0000
change the warp setting for hdim32 fmha fwd
commit 7c5f5e65e97486c074ef9a138900ed9aafea547e
Author: aska-0096 <haocwang@amd.com>
Date: Mon Aug 11 14:21:09 2025 +0000
tempsave, update the blocksync functions
commit beb0950ad8c6b0366a77f5b82e7d5c5f8663b915
Author: aska-0096 <haocwang@amd.com>
Date: Sun Aug 10 06:00:51 2025 +0000
fix bug in pki4
commit 073db2e18af21f1ed1fb3d1f1c15830838df986f
Author: aska-0096 <haocwang@amd.com>
Date: Sat Aug 9 03:25:12 2025 +0000
fix bugs in gemm
commit 01f2d7bd763f64f19861b8a2a861b50bd0aed70a
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 8 18:35:53 2025 +0000
fix bug on non-gfx950
commit 9a9ca06d59cb1721b4fa70a0d3253fb6b252b37e
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 8 17:53:19 2025 +0000
fix bug
commit 30de97f473685e0bd5b82f15eee2493d9a05cffd
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 8 15:42:15 2025 +0000
fix bugs
commit f449cb85a3cfb27bf86525e9c11a2ecf4f7a73a7
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 8 09:31:01 2025 +0000
fix clangformat with 18.1.3
commit e4cb185c41586d018771a5413efd909d8d53a8c5
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 8 09:07:40 2025 +0000
remove non-necessary change
commit 498f0d44cfba17287cce8d10855cce5c5de263db
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 8 09:04:02 2025 +0000
bug fix, clang format;
commit 3cb648cbc4883e6889340d85f48d803a21b9c805
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 8 08:08:03 2025 +0000
Remove unnecessary changes
commit 9e7ff3b611b7933b65973907a0cae312a15d31c6
Merge: a3c1bfe6d 7f14bd1df
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 8 07:50:12 2025 +0000
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa
commit a3c1bfe6dd64572e4371c7b1b8b5a809aad90c71
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 8 06:19:31 2025 +0000
remove unnecessary files; rename some files
commit 6c257fa27729c005d539b5b71deeba3703031089
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 8 05:46:18 2025 +0000
merge fa_decode pipeline into fmha_fwd api
commit 26c911b4e5e43aa78fadc5b7c7880421b94d9449
Author: aska-0096 <haocwang@amd.com>
Date: Wed Aug 6 05:58:43 2025 +0000
add __restrict__ to tr load
commit bbad2b979b701533b74f43452ffe0f775e019139
Author: aska-0096 <haocwang@amd.com>
Date: Tue Aug 5 07:23:51 2025 +0000
Add XOR fold strategy for hdim<128, but perf dropped; disable it by default; wait further perf debug
commit d7fabd5f765e2a573ddbaf0857ce6f691407e562
Author: aska-0096 <haocwang@amd.com>
Date: Mon Aug 4 10:27:42 2025 +0000
Add v_permlaneb32 for block_reduce. Disable it as it will cause un-coexecutable packed math in FA
commit 9f2c1c5baddaa3a2aa9cd70c4a62401df3c29fd9
Author: aska-0096 <haocwang@amd.com>
Date: Mon Aug 4 10:02:17 2025 +0000
add vmcnt guard before load ktile
commit f9772f8b6035bc92aa08fb4d092fc21b6b24445c
Author: aska-0096 <haocwang@amd.com>
Date: Mon Aug 4 06:49:01 2025 +0000
Load Q through lds, implement xor;
commit 62bb9f05177dfb8280d6c2be67a88492d6be4838
Author: aska-0096 <haocwang@amd.com>
Date: Fri Aug 1 10:44:54 2025 +0000
small refactor
commit 7cb83c2ab6a87d161259eeb8d5ac3e27ce9587af
Author: aska-0096 <haocwang@amd.com>
Date: Thu Jul 31 10:25:37 2025 +0000
upgrade prefill pipeline; simple iglp; consistent data produce and consume order
commit 3a85dee389c424490a5101f05c3f4aa3a1ea70be
Author: aska-0096 <haocwang@amd.com>
Date: Thu Jul 31 05:13:27 2025 +0000
enable larger tile size; upgrade xor pattern
commit a468e59a01d6dd85c105ca30ac249491256c5915
Author: aska-0096 <haocwang@amd.com>
Date: Wed Jul 30 12:25:33 2025 +0000
remove all lds bankconflict with xor layouts
commit 39ff55cdc377311112100fb24bc013adfd8960c0
Author: aska-0096 <haocwang@amd.com>
Date: Wed Jul 30 03:51:06 2025 +0000
enable prefill overload operator().
commit a7b152a788e8035c93f8e4cbf317863182665d8f
Author: aska-0096 <haocwang@amd.com>
Date: Fri Jul 25 07:10:01 2025 +0000
fix the lds alignment caused performance regression
commit c4e99bc8f502cd019a754cc9e0043e3d8b9d0f3e
Author: aska-0096 <haocwang@amd.com>
Date: Wed Jul 23 09:05:57 2025 +0000
remove unnecessary features
commit 9758750801c7fd5a80f654eb982f43b87d674fa3
Author: aska-0096 <haocwang@amd.com>
Date: Tue Jul 22 08:04:05 2025 +0000
tempsave. asynccopy+trload sanity checked
commit 1c4c04d725047357224ebf8a2b94d9010a5651a6
Author: aska-0096 <haocwang@amd.com>
Date: Mon Jul 21 05:55:55 2025 +0000
tempsave, trload+asyncload done
commit 75e68f91fc5a1f35cd5d96901efe15c346a1bd5c
Author: aska-0096 <haocwang@amd.com>
Date: Fri Jul 18 10:04:34 2025 +0000
compile pass
commit d41b5eace939909084d32281710fb81142ad5fec
Merge: 3f86a81ee 33204e15f
Author: aska-0096 <haocwang@amd.com>
Date: Fri Jul 18 05:17:27 2025 +0000
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa
commit 3f86a81eee75256a78df02032d50814aaa42b038
Author: aska-0096 <haocwang@amd.com>
Date: Fri Jul 18 05:16:39 2025 +0000
tempsave
commit 7d43f7446a9a20773f70e08462393f6c9afb7280
Author: aska-0096 <haocwang@amd.com>
Date: Thu Jul 17 10:06:09 2025 +0000
temp save
commit 727629cd9115f1be9c1800bb65a8ea84ff06c250
Merge: aa5da19c9 94bceebc9
Author: aska-0096 <haocwang@amd.com>
Date: Thu Jul 17 07:24:32 2025 +0000
Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into fa_decode_pipeline
commit 94bceebc96ef4885e0ac861b7793e2e2897481bd
Author: aska-0096 <haocwang@amd.com>
Date: Thu Jul 17 03:10:46 2025 +0000
move test_copy into test
commit 8f8bfe7f33884f1588bb7aa1a1d599521f40a30e
Author: aska-0096 <haocwang@amd.com>
Date: Thu Jul 17 02:41:31 2025 +0000
remove unnecessary output
commit b1dbcacb1832560c6cc967a079dffce558228f0b
Merge: 5b0d311e6 0eaf3325a
Author: aska-0096 <haocwang@amd.com>
Date: Thu Jul 17 02:26:13 2025 +0000
Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into test_copy_fix
commit 5b0d311e649257557a7014c28fcfac0c327b77b5
Author: aska-0096 <haocwang@amd.com>
Date: Thu Jul 17 02:26:10 2025 +0000
add input validation and bug fix
commit 0eaf3325a8e019402ff12a2402f446f8471f584f
Merge: a66e1d29a 08c5df68a
Author: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Date: Wed Jul 16 11:23:57 2025 -0700
Merge branch 'develop' into test_copy_fix
commit a66e1d29a8cccc17cc8958d970ec7b1281ec8291
Author: aska-0096 <haocwang@amd.com>
Date: Wed Jul 16 08:55:50 2025 +0000
fix vmcnt shift
commit 197bdcb4827dae6d8460ed375e6265c2c9ddaef0
Author: aska-0096 <haocwang@amd.com>
Date: Wed Jul 16 08:37:07 2025 +0000
Improve s_waitcnt_imm calculation
commit 3b59e26cf8e0ba573a99a6caa0f37296b23b8bd2
Author: aska-0096 <haocwang@amd.com>
Date: Wed Jul 16 05:39:50 2025 +0000
fix the s_waitcnt_imm calculation
commit 1c0870089a0e7c78ed71a278bf52d98fc780e482
Merge: d6ee05e36 d9de58c66
Author: aska-0096 <haocwang@amd.com>
Date: Wed Jul 16 03:57:57 2025 +0000
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into test_copy_fix
commit d6ee05e360dc8426ed2a08a8d6877ebf5cabbd32
Author: aska-0096 <haocwang@amd.com>
Date: Wed Jul 16 03:54:33 2025 +0000
Add block_sync_lds_direct_load utility
commit c037a72040217471f52ee76bed9c07bf5b22aef4
Author: aska-0096 <haocwang@amd.com>
Date: Tue Jul 15 09:39:03 2025 +0000
fix async copytest bug
commit aa5da19c94022449b027e7a57668f2e219f0f171
Author: aska-0096 <haocwang@amd.com>
Date: Thu Jul 10 04:29:33 2025 +0000
temp save, change all instance to 1wave
commit ddd172feb9eb2cb783420a8db6f44d51b350c370
Author: aska-0096 <haocwang@amd.com>
Date: Tue Jul 8 08:37:20 2025 +0000
tempsave, fmha_decode
commit fd90531f4eafdfdbf7df0f3731018fc57dcf4a33
Author: aska-0096 <haocwang@amd.com>
Date: Sat Jun 21 15:02:57 2025 +0000
temp save, waiting for debug
commit 71dd31f15bca01995c8cb0be9e903103f4657181
Author: aska-0096 <haocwang@amd.com>
Date: Thu Jun 19 05:11:52 2025 +0000
save an example for __bf16 type
commit cdf33e079fa7d7d5b03b06550df2356b02041d7b
Author: aska-0096 <haocwang@amd.com>
Date: Wed Jun 18 07:27:24 2025 +0000
fix bwd code
commit d630998dc6751f44097b1e9a239bb5063a793736
Author: aska-0096 <haocwang@amd.com>
Date: Wed Jun 18 06:37:16 2025 +0000
Fix for fwd/bwd kernel build filter
commit d5ec3d0e5768aafed7f77151b2a835e87b9f95ba
Author: Ding, Yi <yi.ding@amd.com>
Date: Tue Aug 19 08:13:18 2025 +0000
Add restrict to avoid unnecessary vmcnt
---------
Co-authored-by: aska-0096 <haocwang@amd.com>
* Add comments for c-stype cast
* Better comments
---------
Co-authored-by: aska-0096 <haocwang@amd.com>
[ROCm/composable_kernel commit: de61e55493]
This commit is contained in:
@@ -1833,14 +1833,17 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
if constexpr(oob_conditional_check)
|
||||
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
// Use C-style cast to change address space without dropping llvm noalias attribute
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
(as3_uint32_ptr)(smem),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
@@ -2788,23 +2791,26 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
{
|
||||
#define __LDS_ADDR __attribute__((address_space(3)))
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
"We need to have the compatible compiler version to build this instruction");
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
// Use C-style cast to change address space without dropping llvm noalias attribute
|
||||
const auto in_ptr_ = (__LDS_ADDR T*)(const_cast<T*>(in_ptr));
|
||||
#pragma clang diagnostic pop
|
||||
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
|
||||
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_fp16x4_t*>(in_ptr_);
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_bf16x4_t*>(in_ptr_);
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
|
||||
@@ -2812,15 +2818,14 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
|
||||
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_);
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "not implemented");
|
||||
}
|
||||
#undef __LDS_ADDR
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -1603,14 +1603,17 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
if constexpr(oob_conditional_check)
|
||||
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
// Use C-style cast to change address space without dropping llvm noalias attribute
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
(as3_uint32_ptr)(smem),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
@@ -2606,23 +2609,26 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
{
|
||||
#define __LDS_ADDR __attribute__((address_space(3)))
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
"We need to have the compatible compiler version to build this instruction");
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
// Use C-style cast to change address space without dropping llvm noalias attribute
|
||||
const auto in_ptr_ = (__LDS_ADDR T*)(const_cast<T*>(in_ptr));
|
||||
#pragma clang diagnostic pop
|
||||
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
|
||||
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_fp16x4_t*>(in_ptr_);
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_bf16x4_t*>(in_ptr_);
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
|
||||
@@ -2630,15 +2636,14 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
|
||||
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_);
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "not implemented");
|
||||
}
|
||||
#undef __LDS_ADDR
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -62,12 +62,12 @@ struct buffer_view<address_space_enum::generic,
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
|
||||
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
|
||||
BufferSizeType buffer_size,
|
||||
T invalid_element_value)
|
||||
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
|
||||
@@ -243,7 +243,7 @@ struct buffer_view<address_space_enum::global,
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
|
||||
: p_data_{p_data},
|
||||
buffer_size_{buffer_size / PackedSize},
|
||||
cached_buf_res_{0},
|
||||
@@ -251,7 +251,7 @@ struct buffer_view<address_space_enum::global,
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
|
||||
BufferSizeType buffer_size,
|
||||
T invalid_element_value)
|
||||
: p_data_{p_data},
|
||||
@@ -762,12 +762,12 @@ struct buffer_view<address_space_enum::lds,
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
|
||||
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
|
||||
BufferSizeType buffer_size,
|
||||
T invalid_element_value)
|
||||
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
|
||||
@@ -1121,12 +1121,12 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
|
||||
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
|
||||
BufferSizeType buffer_size,
|
||||
T invalid_element_value)
|
||||
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
|
||||
@@ -1253,7 +1253,7 @@ template <address_space_enum BufferAddressSpace,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename T,
|
||||
typename BufferSizeType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_size)
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size)
|
||||
{
|
||||
return buffer_view<BufferAddressSpace, T, BufferSizeType, true, Coherence>{p, buffer_size};
|
||||
}
|
||||
@@ -1266,7 +1266,7 @@ template <address_space_enum BufferAddressSpace,
|
||||
typename std::enable_if<std::is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
|
||||
make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size, X invalid_element_value)
|
||||
{
|
||||
return buffer_view<BufferAddressSpace, T, BufferSizeType, false, Coherence>{
|
||||
p, buffer_size, invalid_element_value};
|
||||
|
||||
@@ -449,7 +449,7 @@ template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* __restrict__ p,
|
||||
const tensor_descriptor<Ts...>& desc)
|
||||
{
|
||||
auto buffer_view =
|
||||
@@ -468,7 +468,7 @@ template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1,
|
||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_view(DataType* p,
|
||||
make_naive_tensor_view(DataType* __restrict__ p,
|
||||
const tuple<Lengths...>& lengths,
|
||||
const tuple<Strides...>& strides,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
|
||||
@@ -491,7 +491,7 @@ template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
typename... Lengths,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_view_packed(DataType* p,
|
||||
make_naive_tensor_view_packed(DataType* __restrict__ p,
|
||||
const tuple<Lengths...>& lengths,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
|
||||
{
|
||||
|
||||
@@ -1115,7 +1115,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
{i_n0, 0});
|
||||
if constexpr(!kUseQrQtrDorPipeline)
|
||||
{
|
||||
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
|
||||
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(smem_ptr,
|
||||
q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
@@ -1131,7 +1132,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
|
||||
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
|
||||
@@ -1139,7 +1139,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
FmhaPipeline{}(smem_ptr,
|
||||
q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
@@ -1160,7 +1161,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,7 +93,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
operator()(void* smem_ptr,
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
@@ -109,7 +110,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
|
||||
@@ -93,7 +93,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
operator()(void* smem_ptr,
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
@@ -109,7 +110,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
|
||||
@@ -90,6 +90,53 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
else
|
||||
return raw_lse;
|
||||
};
|
||||
template <typename... Ts>
|
||||
CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const
|
||||
{
|
||||
// LDS allocation
|
||||
// cast to char* to do pointer arithmetic
|
||||
const auto smem_ptr_ = reinterpret_cast<char*>(smem_ptr);
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr_);
|
||||
const auto v_lds_ptr =
|
||||
reinterpret_cast<VDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
|
||||
|
||||
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType*>(smem_ptr_);
|
||||
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>());
|
||||
const auto ds_lds_ptr = reinterpret_cast<GemmDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
|
||||
return run(k_lds_ptr,
|
||||
v_lds_ptr,
|
||||
do_lds_ptr0,
|
||||
do_lds_ptr1,
|
||||
q_lds_ptr0,
|
||||
q_lds_ptr1,
|
||||
lse_lds_ptr,
|
||||
d_lds_ptr,
|
||||
ds_lds_ptr,
|
||||
bias_lds_ptr,
|
||||
std::forward<Ts>(args)...);
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
@@ -102,7 +149,17 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_DEVICE auto operator()( //
|
||||
CK_TILE_DEVICE auto run( //
|
||||
KDataType* __restrict__ k_lds_ptr,
|
||||
VDataType* __restrict__ v_lds_ptr,
|
||||
OGradDataType* __restrict__ do_lds_ptr0,
|
||||
OGradDataType* __restrict__ do_lds_ptr1,
|
||||
QDataType* __restrict__ q_lds_ptr0,
|
||||
QDataType* __restrict__ q_lds_ptr1,
|
||||
LSEDataType* __restrict__ lse_lds_ptr,
|
||||
DDataType* __restrict__ d_lds_ptr,
|
||||
GemmDataType* __restrict__ ds_lds_ptr,
|
||||
BiasDataType* __restrict__ bias_lds_ptr,
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
@@ -119,7 +176,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
@@ -184,40 +240,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
}
|
||||
}
|
||||
|
||||
// LDS allocation
|
||||
const auto smem_ptr_ =
|
||||
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
|
||||
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
|
||||
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
|
||||
|
||||
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType* __restrict__>(smem_ptr_);
|
||||
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr0 = reinterpret_cast<QDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr1 = reinterpret_cast<QDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>());
|
||||
const auto ds_lds_ptr = reinterpret_cast<GemmDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType* __restrict__>(ds_lds_ptr);
|
||||
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window =
|
||||
@@ -453,13 +475,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
decltype(load_tile(d_dram_window)) d_block_tile;
|
||||
|
||||
index_t i_total_bodys = 0;
|
||||
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
|
||||
const bool is_even = (i_total_bodys % 2 == 0);
|
||||
QDataType* const __restrict__ q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
|
||||
QDataType* const __restrict__ q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
|
||||
OGradDataType* const __restrict__ do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
|
||||
OGradDataType* const __restrict__ do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
|
||||
|
||||
auto main_body_impl = [&](auto is_prologue_,
|
||||
auto is_epilogue_,
|
||||
QDataType* const __restrict__ q_lds_ptr_curr,
|
||||
QDataType* const __restrict__ q_lds_ptr_next,
|
||||
OGradDataType* const __restrict__ do_lds_ptr_curr,
|
||||
OGradDataType* const __restrict__ do_lds_ptr_next) mutable {
|
||||
constexpr bool is_prologue = is_prologue_.value;
|
||||
constexpr bool is_epilogue = is_epilogue_.value;
|
||||
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
|
||||
@@ -467,19 +488,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
|
||||
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
|
||||
async_load_tile(q_lds_write_window, q_dram_window);
|
||||
move_tile_window(q_dram_window, {kM0, 0});
|
||||
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
|
||||
do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
|
||||
async_load_tile(do_lds_write_window, do_dram_window);
|
||||
move_tile_window(do_dram_window, {kM0, 0});
|
||||
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
@@ -611,8 +632,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = p[i_j_idx] >= 0;
|
||||
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
|
||||
? (dp_acc[i_j_idx] - d[i_idx])
|
||||
: d[i_idx]);
|
||||
? (dp_acc[i_j_idx] - d[i_idx])
|
||||
: d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -725,6 +746,20 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
}
|
||||
move_tile_window(dq_dram_window, {kM0, 0});
|
||||
}
|
||||
};
|
||||
|
||||
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
|
||||
const bool is_even = (i_total_bodys % 2 == 0);
|
||||
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
|
||||
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
|
||||
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
|
||||
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
|
||||
main_body_impl(is_prologue_,
|
||||
is_epilogue_,
|
||||
q_lds_ptr_curr,
|
||||
q_lds_ptr_next,
|
||||
do_lds_ptr_curr,
|
||||
do_lds_ptr_next);
|
||||
i_total_bodys += 1;
|
||||
};
|
||||
|
||||
|
||||
@@ -93,6 +93,42 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
return raw_lse;
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const
|
||||
{
|
||||
// LDS allocation
|
||||
const auto smem_ptr_ =
|
||||
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
|
||||
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
|
||||
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
|
||||
|
||||
const auto do_lds_ptr = reinterpret_cast<OGradDataType*>(smem_ptr_);
|
||||
const auto q_lds_ptr = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>());
|
||||
|
||||
const auto ds_lds_ptr =
|
||||
reinterpret_cast<GemmDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeV<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
|
||||
return run(k_lds_ptr,
|
||||
v_lds_ptr,
|
||||
do_lds_ptr,
|
||||
q_lds_ptr,
|
||||
lse_lds_ptr,
|
||||
d_lds_ptr,
|
||||
ds_lds_ptr,
|
||||
bias_lds_ptr,
|
||||
std::forward<Ts>(args)...);
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
@@ -109,7 +145,15 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
typename KGradEpilogue,
|
||||
typename VGradEpilogue,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_DEVICE auto operator()( //
|
||||
CK_TILE_DEVICE auto run( //
|
||||
KDataType* __restrict__ k_lds_ptr,
|
||||
VDataType* __restrict__ v_lds_ptr,
|
||||
OGradDataType* __restrict__ do_lds_ptr,
|
||||
QDataType* __restrict__ q_lds_ptr,
|
||||
LSEDataType* __restrict__ lse_lds_ptr,
|
||||
DDataType* __restrict__ d_lds_ptr,
|
||||
GemmDataType* __restrict__ ds_lds_ptr,
|
||||
BiasDataType* __restrict__ bias_lds_ptr,
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
@@ -131,7 +175,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
@@ -181,29 +224,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
{seqlen_kv_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
// LDS allocation
|
||||
const auto smem_ptr_ =
|
||||
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
|
||||
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
|
||||
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
|
||||
|
||||
const auto do_lds_ptr = reinterpret_cast<OGradDataType*>(smem_ptr_);
|
||||
const auto q_lds_ptr = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>());
|
||||
|
||||
const auto ds_lds_ptr =
|
||||
reinterpret_cast<GemmDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeV<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
|
||||
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window =
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -638,11 +638,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
|
||||
operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& __restrict__ bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& __restrict__ lse_acc_dram_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
@@ -854,18 +854,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto mainloop = [&](index_t cur_loop) {
|
||||
const bool is_even_loop = (cur_loop % 2 == 0);
|
||||
|
||||
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk1);
|
||||
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk0);
|
||||
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv0);
|
||||
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv1);
|
||||
|
||||
auto mainloop = [&](KDataType* __restrict__ k_lds_write_ptr,
|
||||
KDataType* __restrict__ k_lds_read_ptr,
|
||||
KDataType* __restrict__ v_lds_write_ptr,
|
||||
KDataType* __restrict__ v_lds_read_ptr) {
|
||||
// move V tile windows
|
||||
block_sync_lds<k_lds_insts>();
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
@@ -1110,11 +1102,20 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
|
||||
});
|
||||
};
|
||||
}; // mainloop
|
||||
|
||||
do
|
||||
{
|
||||
mainloop(i_total_loops);
|
||||
bool is_even_loop = i_total_loops % 2 == 0;
|
||||
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk1);
|
||||
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk0);
|
||||
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv0);
|
||||
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv1);
|
||||
mainloop(k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
|
||||
i_total_loops++;
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user