[CK_TILE] FMHA avoid unnecessary vmcnt0 (#2715)

* FMHA avoid unnecessary vmcnt0

Squashed commit of the following:

commit 7bdf6a7eef
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 22 03:15:51 2025 +0000

    merge develop and solve conflicts

commit f21e916a8c
Merge: a7dd2a7d1 0db21053e
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 22 03:15:21 2025 +0000

    Merge branch 'develop' of https://github.com/ROCm/composable_kernel into vmcnt0issue

commit a7dd2a7d13
Author: Ding, Yi <yi.ding@amd.com>
Date:   Tue Aug 19 02:17:43 2025 +0000

    update bwd

commit 380aa8f311
Author: Kevin Choi <kevin.choi@amd.com>
Date:   Mon Aug 18 19:36:38 2025 +0000

    add restrict to applicable functions

commit b85daba2a3
Author: Ding, Yi <yi.ding@amd.com>
Date:   Mon Aug 18 02:07:03 2025 +0000

    bwd filter

commit 75c4b9372f
Author: Kevin Choi <kevin.choi@amd.com>
Date:   Sat Aug 16 08:15:23 2025 +0000

    remove noinline attr as it causes a lot more s_waitcnt's

commit 598e3fec41
Author: Kevin Choi <kevin.choi@amd.com>
Date:   Thu Aug 14 12:11:17 2025 +0000

    remove innerloop, move restrict parameters to mainloop and add noinline attribute.

commit 3340408537
Author: Kevin Choi <kevin.choi@amd.com>
Date:   Thu Aug 14 07:06:51 2025 +0000

    Create inner lambda with restrict parameters, add restrict to some parameters

commit 3bc45ecbc7
Author: aska-0096 <haocwang@amd.com>
Date:   Thu Aug 14 03:43:54 2025 +0000

    save for debug

commit de4db6c4c5
Merge: 108abf00e 68694cb78
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Aug 13 02:15:22 2025 +0000

    Merge branch 'wip-async-tr-fa' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa

commit 108abf00e0
Merge: 0810799e2 0f42a92fc
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Aug 13 02:14:26 2025 +0000

    Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa

commit 68694cb781
Merge: 0810799e2 20288caa2
Author: asleepzzz <hanwen.chang@amd.com>
Date:   Wed Aug 13 00:34:11 2025 +0800

    Merge branch 'develop' into wip-async-tr-fa

commit 0810799e25
Author: aska-0096 <haocwang@amd.com>
Date:   Tue Aug 12 14:25:50 2025 +0000

    refactor blockgemm change, isolate to v2;

commit fd1eb323af
Author: aska-0096 <haocwang@amd.com>
Date:   Tue Aug 12 09:26:13 2025 +0000

    clang format

commit 75f6f6bac4
Merge: bcc05eee6 8e1eb0c1e
Author: aska-0096 <haocwang@amd.com>
Date:   Tue Aug 12 09:04:41 2025 +0000

    Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa

commit bcc05eee62
Author: aska-0096 <haocwang@amd.com>
Date:   Tue Aug 12 08:46:06 2025 +0000

    Fix the bug

commit 96d24497f5
Author: aska-0096 <haocwang@amd.com>
Date:   Tue Aug 12 04:02:41 2025 +0000

    fix conflict. disable all v-col instance for fmha fwd

commit 1716171be4
Merge: 1c9800790 4fde1646e
Author: aska-0096 <haocwang@amd.com>
Date:   Tue Aug 12 03:52:34 2025 +0000

    Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa

commit 1c98007901
Author: aska-0096 <haocwang@amd.com>
Date:   Tue Aug 12 01:53:31 2025 +0000

    clang format

commit f43e903b1d
Merge: 3868ddd70 a7badc6ec
Author: aska-0096 <haocwang@amd.com>
Date:   Tue Aug 12 01:52:52 2025 +0000

    Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa

commit 3868ddd708
Merge: 498d234ab 191c62967
Author: aska-0096 <haocwang@amd.com>
Date:   Mon Aug 11 15:59:40 2025 +0000

    Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa

commit 498d234ab8
Author: aska-0096 <haocwang@amd.com>
Date:   Mon Aug 11 15:37:37 2025 +0000

    change the warp setting for hdim32 fmha fwd

commit b86f7786e2
Author: aska-0096 <haocwang@amd.com>
Date:   Mon Aug 11 14:21:09 2025 +0000

    tempsave, update the blocksync functions

commit 7b8052d7ca
Author: aska-0096 <haocwang@amd.com>
Date:   Sun Aug 10 06:00:51 2025 +0000

    fix bug in pki4

commit 76cbbb84a2
Author: aska-0096 <haocwang@amd.com>
Date:   Sat Aug 9 03:25:12 2025 +0000

    fix bugs in gemm

commit 8c101ccb88
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 8 18:35:53 2025 +0000

    fix bug on non-gfx950

commit efb8549279
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 8 17:53:19 2025 +0000

    fix bug

commit 729e8785fb
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 8 15:42:15 2025 +0000

    fix bugs

commit 250dc13c75
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 8 09:31:01 2025 +0000

    fix clangformat with 18.1.3

commit 106edeecd9
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 8 09:07:40 2025 +0000

    remove non-necessary change

commit 78edd7303b
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 8 09:04:02 2025 +0000

    bug fix, clang format;

commit 3b9fb6af38
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 8 08:08:03 2025 +0000

    Remove unnecessary changes

commit 6bb57c2c57
Merge: 1ecee378d ab2602683
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 8 07:50:12 2025 +0000

    Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa

commit 1ecee378d5
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 8 06:19:31 2025 +0000

    remove unnecessary files; rename some files

commit b4640a9de6
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 8 05:46:18 2025 +0000

    merge fa_decode pipeline into fmha_fwd api

commit fe63a646a4
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Aug 6 05:58:43 2025 +0000

    add __restrict__ to tr load

commit 414cad667b
Author: aska-0096 <haocwang@amd.com>
Date:   Tue Aug 5 07:23:51 2025 +0000

    Add XOR fold strategy for hdim<128, but perf dropped; disable it by default; wait further perf debug

commit 0d12fc944f
Author: aska-0096 <haocwang@amd.com>
Date:   Mon Aug 4 10:27:42 2025 +0000

    Add v_permlaneb32 for block_reduce. Disable it as it will cause un-coexecutable packed math in FA

commit 4f31847de1
Author: aska-0096 <haocwang@amd.com>
Date:   Mon Aug 4 10:02:17 2025 +0000

    add vmcnt guard before load ktile

commit 746f4ccb99
Author: aska-0096 <haocwang@amd.com>
Date:   Mon Aug 4 06:49:01 2025 +0000

    Load Q through lds, implement xor;

commit 2d4e73d2b4
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Aug 1 10:44:54 2025 +0000

    small refactor

commit a28b6e67fe
Author: aska-0096 <haocwang@amd.com>
Date:   Thu Jul 31 10:25:37 2025 +0000

    upgrade prefill pipeline; simple iglp; consistent data produce and consume order

commit 75cba48682
Author: aska-0096 <haocwang@amd.com>
Date:   Thu Jul 31 05:13:27 2025 +0000

    enable larger tile size; upgrade xor pattern

commit 69890afc98
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Jul 30 12:25:33 2025 +0000

    remove all lds bankconflict with xor layouts

commit 8dacc35c4c
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Jul 30 03:51:06 2025 +0000

    enable prefill overload operator().

commit 13bcc913de
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Jul 25 07:10:01 2025 +0000

    fix the lds alignment caused performance regression

commit af28123cec
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Jul 23 09:05:57 2025 +0000

    remove unnecessary features

commit 14e0ab70c6
Author: aska-0096 <haocwang@amd.com>
Date:   Tue Jul 22 08:04:05 2025 +0000

    tempsave. asynccopy+trload sanity checked

commit 1b468bac0b
Author: aska-0096 <haocwang@amd.com>
Date:   Mon Jul 21 05:55:55 2025 +0000

    tempsave, trload+asyncload done

commit afd96d8180
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Jul 18 10:04:34 2025 +0000

    compile pass

commit 5616551115
Merge: ae39c84f5 095393276
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Jul 18 05:17:27 2025 +0000

    Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa

commit ae39c84f55
Author: aska-0096 <haocwang@amd.com>
Date:   Fri Jul 18 05:16:39 2025 +0000

    tempsave

commit 94b6430489
Author: aska-0096 <haocwang@amd.com>
Date:   Thu Jul 17 10:06:09 2025 +0000

    temp save

commit 7e330553dc
Merge: 18669925c 804f77dce
Author: aska-0096 <haocwang@amd.com>
Date:   Thu Jul 17 07:24:32 2025 +0000

    Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into fa_decode_pipeline

commit 804f77dce5
Author: aska-0096 <haocwang@amd.com>
Date:   Thu Jul 17 03:10:46 2025 +0000

    move test_copy into test

commit 21627d7ca7
Author: aska-0096 <haocwang@amd.com>
Date:   Thu Jul 17 02:41:31 2025 +0000

    remove unnecessary output

commit 287792c44a
Merge: a4221db30 21fd7e953
Author: aska-0096 <haocwang@amd.com>
Date:   Thu Jul 17 02:26:13 2025 +0000

    Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into test_copy_fix

commit a4221db304
Author: aska-0096 <haocwang@amd.com>
Date:   Thu Jul 17 02:26:10 2025 +0000

    add input validation and bug fix

commit 21fd7e9538
Merge: d6df7bf85 6e76b8205
Author: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Date:   Wed Jul 16 11:23:57 2025 -0700

    Merge branch 'develop' into test_copy_fix

commit d6df7bf851
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Jul 16 08:55:50 2025 +0000

    fix vmcnt shift

commit 40e039e4e4
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Jul 16 08:37:07 2025 +0000

    Improve s_waitcnt_imm calculation

commit c30f8b709b
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Jul 16 05:39:50 2025 +0000

    fix the s_waitcnt_imm calculation

commit ec0a45b29f
Merge: e5cc4af80 6b09f0823
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Jul 16 03:57:57 2025 +0000

    Merge branch 'develop' of https://github.com/ROCm/composable_kernel into test_copy_fix

commit e5cc4af808
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Jul 16 03:54:33 2025 +0000

    Add block_sync_lds_direct_load utility

commit eea58629cf
Author: aska-0096 <haocwang@amd.com>
Date:   Tue Jul 15 09:39:03 2025 +0000

    fix async copytest bug

commit 18669925cc
Author: aska-0096 <haocwang@amd.com>
Date:   Thu Jul 10 04:29:33 2025 +0000

    temp save, change all instance to 1wave

commit 18686cfe5b
Author: aska-0096 <haocwang@amd.com>
Date:   Tue Jul 8 08:37:20 2025 +0000

    tempsave, fmha_decode

commit 47565f21a5
Author: aska-0096 <haocwang@amd.com>
Date:   Sat Jun 21 15:02:57 2025 +0000

    temp save, waiting for debug

commit e0a634ef97
Author: aska-0096 <haocwang@amd.com>
Date:   Thu Jun 19 05:11:52 2025 +0000

    save an example for __bf16 type

commit 4bd5fd4a3c
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Jun 18 07:27:24 2025 +0000

    fix bwd code

commit 69809d9513
Author: aska-0096 <haocwang@amd.com>
Date:   Wed Jun 18 06:37:16 2025 +0000

    Fix for fwd/bwd kernel build filter

commit d5ec3d0e5768aafed7f77151b2a835e87b9f95ba
Author: Ding, Yi <yi.ding@amd.com>
Date:   Tue Aug 19 08:13:18 2025 +0000

    Add restrict to avoid unnecessary vmcnt

---------

Co-authored-by: aska-0096 <haocwang@amd.com>

* Add comments for c-stype cast

* Better comments

---------

Co-authored-by: aska-0096 <haocwang@amd.com>
This commit is contained in:
Yi DING
2025-08-25 20:55:12 +08:00
committed by GitHub
parent c71d7ddd74
commit de61e55493
10 changed files with 217 additions and 151 deletions

View File

@@ -93,7 +93,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
operator()(void* smem_ptr,
const QDramBlockWindowTmp& q_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
@@ -109,7 +110,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
float scale,
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
FmhaDropout& dropout) const
{
static_assert(

View File

@@ -93,7 +93,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
operator()(void* smem_ptr,
const QDramBlockWindowTmp& q_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
@@ -109,7 +110,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
float scale,
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
FmhaDropout& dropout) const
{
static_assert(

View File

@@ -90,6 +90,53 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
else
return raw_lse;
};
template <typename... Ts>
CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const
{
// LDS allocation
// cast to char* to do pointer arithmetic
const auto smem_ptr_ = reinterpret_cast<char*>(smem_ptr);
const auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr_);
const auto v_lds_ptr =
reinterpret_cast<VDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType*>(smem_ptr_);
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>());
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>());
const auto ds_lds_ptr = reinterpret_cast<GemmDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
return run(k_lds_ptr,
v_lds_ptr,
do_lds_ptr0,
do_lds_ptr1,
q_lds_ptr0,
q_lds_ptr1,
lse_lds_ptr,
d_lds_ptr,
ds_lds_ptr,
bias_lds_ptr,
std::forward<Ts>(args)...);
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
@@ -102,7 +149,17 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
typename QGradDramBlockWindowTmp,
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_DEVICE auto operator()( //
CK_TILE_DEVICE auto run( //
KDataType* __restrict__ k_lds_ptr,
VDataType* __restrict__ v_lds_ptr,
OGradDataType* __restrict__ do_lds_ptr0,
OGradDataType* __restrict__ do_lds_ptr1,
QDataType* __restrict__ q_lds_ptr0,
QDataType* __restrict__ q_lds_ptr1,
LSEDataType* __restrict__ lse_lds_ptr,
DDataType* __restrict__ d_lds_ptr,
GemmDataType* __restrict__ ds_lds_ptr,
BiasDataType* __restrict__ bias_lds_ptr,
const QDramBlockWindowTmp& q_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
@@ -119,7 +176,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
float scale,
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
FmhaDropout& dropout) const
{
static_assert(
@@ -184,40 +240,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
}
}
// LDS allocation
const auto smem_ptr_ =
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType* __restrict__>(smem_ptr_);
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr0 = reinterpret_cast<QDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr1 = reinterpret_cast<QDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>());
const auto lse_lds_ptr = reinterpret_cast<LSEDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
const auto d_lds_ptr = reinterpret_cast<DDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>());
const auto ds_lds_ptr = reinterpret_cast<GemmDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
const auto bias_lds_ptr = reinterpret_cast<BiasDataType* __restrict__>(ds_lds_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =
@@ -453,13 +475,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
decltype(load_tile(d_dram_window)) d_block_tile;
index_t i_total_bodys = 0;
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
const bool is_even = (i_total_bodys % 2 == 0);
QDataType* const __restrict__ q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
QDataType* const __restrict__ q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
OGradDataType* const __restrict__ do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
OGradDataType* const __restrict__ do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
auto main_body_impl = [&](auto is_prologue_,
auto is_epilogue_,
QDataType* const __restrict__ q_lds_ptr_curr,
QDataType* const __restrict__ q_lds_ptr_next,
OGradDataType* const __restrict__ do_lds_ptr_curr,
OGradDataType* const __restrict__ do_lds_ptr_next) mutable {
constexpr bool is_prologue = is_prologue_.value;
constexpr bool is_epilogue = is_epilogue_.value;
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
@@ -467,19 +488,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
if constexpr(is_prologue)
{
lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
async_load_tile(q_lds_write_window, q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
async_load_tile(do_lds_write_window, do_dram_window);
move_tile_window(do_dram_window, {kM0, 0});
d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
}
if constexpr(is_epilogue)
{
@@ -611,8 +632,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = p[i_j_idx] >= 0;
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
? (dp_acc[i_j_idx] - d[i_idx])
: d[i_idx]);
? (dp_acc[i_j_idx] - d[i_idx])
: d[i_idx]);
});
});
@@ -725,6 +746,20 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
}
move_tile_window(dq_dram_window, {kM0, 0});
}
};
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
const bool is_even = (i_total_bodys % 2 == 0);
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
main_body_impl(is_prologue_,
is_epilogue_,
q_lds_ptr_curr,
q_lds_ptr_next,
do_lds_ptr_curr,
do_lds_ptr_next);
i_total_bodys += 1;
};

View File

@@ -93,6 +93,42 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
return raw_lse;
};
template <typename... Ts>
CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const
{
// LDS allocation
const auto smem_ptr_ =
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
const auto do_lds_ptr = reinterpret_cast<OGradDataType*>(smem_ptr_);
const auto q_lds_ptr = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>());
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>());
const auto ds_lds_ptr =
reinterpret_cast<GemmDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeV<Problem>());
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
return run(k_lds_ptr,
v_lds_ptr,
do_lds_ptr,
q_lds_ptr,
lse_lds_ptr,
d_lds_ptr,
ds_lds_ptr,
bias_lds_ptr,
std::forward<Ts>(args)...);
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
@@ -109,7 +145,15 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
typename KGradEpilogue,
typename VGradEpilogue,
typename PositionEncoding>
CK_TILE_DEVICE auto operator()( //
CK_TILE_DEVICE auto run( //
KDataType* __restrict__ k_lds_ptr,
VDataType* __restrict__ v_lds_ptr,
OGradDataType* __restrict__ do_lds_ptr,
QDataType* __restrict__ q_lds_ptr,
LSEDataType* __restrict__ lse_lds_ptr,
DDataType* __restrict__ d_lds_ptr,
GemmDataType* __restrict__ ds_lds_ptr,
BiasDataType* __restrict__ bias_lds_ptr,
const QDramBlockWindowTmp& q_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
@@ -131,7 +175,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
float scale,
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
FmhaDropout& dropout) const
{
static_assert(
@@ -181,29 +224,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
{seqlen_kv_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
// LDS allocation
const auto smem_ptr_ =
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
const auto do_lds_ptr = reinterpret_cast<OGradDataType*>(smem_ptr_);
const auto q_lds_ptr = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>());
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>());
const auto ds_lds_ptr =
reinterpret_cast<GemmDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeV<Problem>());
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -638,11 +638,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& __restrict__ bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& __restrict__ lse_acc_dram_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
@@ -854,18 +854,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
__builtin_amdgcn_sched_barrier(0);
auto mainloop = [&](index_t cur_loop) {
const bool is_even_loop = (cur_loop % 2 == 0);
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
: static_cast<KDataType* __restrict__>(smem_ptrk1);
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
: static_cast<KDataType* __restrict__>(smem_ptrk0);
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
: static_cast<VDataType* __restrict__>(smem_ptrv0);
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
: static_cast<VDataType* __restrict__>(smem_ptrv1);
auto mainloop = [&](KDataType* __restrict__ k_lds_write_ptr,
KDataType* __restrict__ k_lds_read_ptr,
KDataType* __restrict__ v_lds_write_ptr,
KDataType* __restrict__ v_lds_read_ptr) {
// move V tile windows
block_sync_lds<k_lds_insts>();
move_tile_window(v_dram_window, {kN0, 0});
@@ -1110,11 +1102,20 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
});
};
}; // mainloop
do
{
mainloop(i_total_loops);
bool is_even_loop = i_total_loops % 2 == 0;
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
: static_cast<KDataType* __restrict__>(smem_ptrk1);
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
: static_cast<KDataType* __restrict__>(smem_ptrk0);
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
: static_cast<VDataType* __restrict__>(smem_ptrv0);
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
: static_cast<VDataType* __restrict__>(smem_ptrv1);
mainloop(k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
i_total_loops++;
} while(i_total_loops < num_total_loop);