Revert "[CK_TILE] FMHA BWD Enable Tile 16x192 (#2741)" (#2757)

This reverts commit ead4447b20.
This commit is contained in:
asleepzzz
2025-08-28 22:50:42 +08:00
committed by GitHub
parent 4a49dac7c6
commit 038ea82315
6 changed files with 114 additions and 173 deletions

View File

@@ -103,41 +103,27 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType*>(smem_ptr_);
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>());
const auto lse_lds_ptr0 = reinterpret_cast<LSEDataType*>(
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
const auto lse_lds_ptr1 = reinterpret_cast<LSEDataType*>(
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>());
const auto d_lds_ptr0 = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeLSE<Problem>());
const auto d_lds_ptr1 = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
const auto ds_lds_ptr = reinterpret_cast<GemmDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>() +
Policy::template GetSmemSizeD<Problem>());
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
return run(k_lds_ptr,
v_lds_ptr,
@@ -145,10 +131,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
do_lds_ptr1,
q_lds_ptr0,
q_lds_ptr1,
lse_lds_ptr0,
lse_lds_ptr1,
d_lds_ptr0,
d_lds_ptr1,
lse_lds_ptr,
d_lds_ptr,
ds_lds_ptr,
bias_lds_ptr,
std::forward<Ts>(args)...);
@@ -172,10 +156,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
OGradDataType* __restrict__ do_lds_ptr1,
QDataType* __restrict__ q_lds_ptr0,
QDataType* __restrict__ q_lds_ptr1,
LSEDataType* __restrict__ lse_lds_ptr0,
LSEDataType* __restrict__ lse_lds_ptr1,
DDataType* __restrict__ d_lds_ptr0,
DDataType* __restrict__ d_lds_ptr1,
LSEDataType* __restrict__ lse_lds_ptr,
DDataType* __restrict__ d_lds_ptr,
GemmDataType* __restrict__ ds_lds_ptr,
BiasDataType* __restrict__ bias_lds_ptr,
const QDramBlockWindowTmp& q_dram_block_window_tmp,
@@ -407,38 +389,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto lse_dram_window = make_tile_window(
lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window =
make_tile_window(lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// D: HBM ->Reg
auto d_dram_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto d_dram_window = make_tile_window(
d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_lds = make_tensor_view<address_space_enum::lds>(
d_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window =
make_tile_window(d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
@@ -489,31 +471,27 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
decltype(gemm_4.MakeCBlockTile()) dq_acc;
decltype(load_tile(lse_dram_window)) lse_block_tile;
decltype(load_tile(d_dram_window)) d_block_tile;
index_t i_total_bodys = 0;
auto main_body_impl = [&](auto is_prologue_,
auto is_epilogue_,
QDataType* const __restrict__ q_lds_ptr_curr,
QDataType* const __restrict__ q_lds_ptr_next,
OGradDataType* const __restrict__ do_lds_ptr_curr,
OGradDataType* const __restrict__ do_lds_ptr_next,
LSEDataType* const __restrict__ lse_lds_ptr_curr,
LSEDataType* const __restrict__ lse_lds_ptr_next,
DDataType* const __restrict__ d_lds_ptr_curr,
DDataType* const __restrict__ d_lds_ptr_next
) mutable {
OGradDataType* const __restrict__ do_lds_ptr_next) mutable {
constexpr bool is_prologue = is_prologue_.value;
constexpr bool is_epilogue = is_epilogue_.value;
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
constexpr bool is_main_body = is_prologue && is_epilogue;
if constexpr(is_prologue)
{
lse_lds_write_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_next);
async_load_tile(lse_lds_write_window, lse_dram_window);
lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
d_lds_write_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_next);
async_load_tile(d_lds_write_window, d_dram_window);
d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
@@ -532,13 +510,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
}
if constexpr(is_epilogue)
{
lse_lds_read_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_curr);
lse = load_tile(lse_lds_read_window);
d_lds_read_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_curr);
d = load_tile(d_lds_read_window);
}
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
__builtin_amdgcn_sched_barrier(0);
@@ -646,6 +617,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
__builtin_amdgcn_sched_barrier(0);
if constexpr(is_prologue)
{
store_tile(lse_lds_write_window, lse_block_tile);
store_tile(d_lds_write_window, d_block_tile);
}
if constexpr(is_epilogue)
{
// STAGE 5, P^T(PGrad^T - D)
@@ -700,12 +676,13 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
store_tile(ds_lds_window, ds_gemm);
}
s_waitcnt</*vmcnt=*/0>();
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
if constexpr(is_prologue)
{
q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
q_reg_tensor = load_tile(q_lds_read_window);
lse = load_tile(lse_lds_read_window);
}
if constexpr(is_epilogue)
{
@@ -743,6 +720,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
{
do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
do_reg_tensor = load_tile(do_lds_read_window);
d = load_tile(d_lds_read_window);
}
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
@@ -771,25 +749,17 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
};
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
const bool is_even = (i_total_bodys % 2 == 0);
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
const auto lse_lds_ptr_curr = is_even ? lse_lds_ptr1 : lse_lds_ptr0;
const auto lse_lds_ptr_next = is_even ? lse_lds_ptr0 : lse_lds_ptr1;
const auto d_lds_ptr_curr = is_even ? d_lds_ptr1 : d_lds_ptr0;
const auto d_lds_ptr_next = is_even ? d_lds_ptr0 : d_lds_ptr1;
const bool is_even = (i_total_bodys % 2 == 0);
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
main_body_impl(is_prologue_,
is_epilogue_,
q_lds_ptr_curr,
q_lds_ptr_next,
do_lds_ptr_curr,
do_lds_ptr_next,
lse_lds_ptr_curr,
lse_lds_ptr_next,
d_lds_ptr_curr,
d_lds_ptr_next);
do_lds_ptr_next);
i_total_bodys += 1;
};

View File

@@ -363,38 +363,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto lse_dram_window = make_tile_window(
lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window =
make_tile_window(lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// D: HBM ->Reg
auto d_dram_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto d_dram_window = make_tile_window(
d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_lds = make_tensor_view<address_space_enum::lds>(
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window =
make_tile_window(d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), true>(

View File

@@ -194,7 +194,13 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad()
{
return GetTransposedAlignmentX<typename Problem::OGradDataType>();
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
return total_pixels / GetAlignmentOGrad<Problem>();
}
template <typename Problem>
@@ -352,30 +358,11 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
Problem::BlockFmhaShape::kVHeaddim>();
}
template <typename Problem>
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t N0 = MWarp * NWarp;
constexpr index_t M1 = kMPerBlock;
constexpr index_t M0 = get_warp_size() / M1;
static_assert(M1 <= get_warp_size() && get_warp_size() % M1 == 0,
"M1 must be a factor of warp size");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N0, M0>,
tuple<sequence<M1, 1>>,
tuple<sequence<0>, sequence<0, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1>,
sequence<1>>{});
return BlockFmhaBwdPipelineDefaultPolicy::MakeLSEDDramTileDistribution<Problem,
BlockGemm>();
}
template <typename Problem>
@@ -806,10 +793,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
return lsed_lds_block_desc;
}
template <typename Problem>
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
@@ -998,16 +984,15 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE()
{
return static_cast<index_t>(max( //
sizeof(int) * get_warp_size(),
sizeof(typename Problem::LSEDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size()));
return sizeof(typename Problem::LSEDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD()
{
return GetSmemSizeLSE<Problem>();
return sizeof(typename Problem::DDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
}
template <typename Problem>
@@ -1054,9 +1039,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
constexpr index_t smem_size_stage0 = smem_size_k + smem_size_v;
constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 +
smem_size_lse * 2 + smem_size_d * 2 +
max(smem_size_bias, smem_size_ds);
constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 + smem_size_lse +
smem_size_d + max(smem_size_bias, smem_size_ds);
return max(smem_size_stage0, smem_size_stage1);
}
@@ -1106,8 +1090,6 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
static constexpr index_t LSE_VMEM_READ = 1;
static constexpr index_t D_VMEM_READ = 1;
static constexpr index_t DQ_VMEM_WRITE = kM0 * kQKHeaddim / kBlockSize; // atomic add
// LDS Read
static constexpr index_t OGradT_LDS_READ =
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
@@ -1134,12 +1116,11 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t OGradT_LDS_WRITE =
kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
static constexpr index_t LSE_LDS_WRITE = 1;
static constexpr index_t D_LDS_WRITE = 1;
static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
public:
static constexpr index_t TOTAL_VMEM_READ =
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ + DQ_VMEM_WRITE;
CK_TILE_DEVICE static constexpr void SchedulerGemm0()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
@@ -1147,7 +1128,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
constexpr index_t VMEM_READ_INST =
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
constexpr index_t MFMA_INST = Gemm0MFMA;
constexpr index_t LDS_READ_INST = OGradT_LDS_READ + LSE_LDS_READ + D_LDS_READ;
constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
constexpr index_t lcm_inst = lcm(VMEM_READ_INST, MFMA_INST, LDS_READ_INST);
static_for<0, lcm_inst, 1>{}([&](auto i) {
@@ -1180,8 +1161,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
{
// Mem: LSE/D LDS store, SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ;
constexpr index_t LDS_WRITE_INST = LSE_LDS_WRITE + D_LDS_WRITE + SGradT_LDS_WRITE;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ;
constexpr index_t MFMA_INST = Gemm3MFMA;
constexpr index_t lds_rw_inst = LDS_WRITE_INST + LDS_READ_INST;
@@ -1204,7 +1185,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ;
constexpr index_t MFMA_INST = Gemm4MFMA;
constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST);