mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
This reverts commit ead4447b20.
This commit is contained in:
@@ -103,41 +103,27 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType*>(smem_ptr_);
|
||||
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
|
||||
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
|
||||
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto lse_lds_ptr0 = reinterpret_cast<LSEDataType*>(
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto lse_lds_ptr1 = reinterpret_cast<LSEDataType*>(
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>());
|
||||
const auto d_lds_ptr0 = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>());
|
||||
const auto d_lds_ptr1 = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
|
||||
const auto ds_lds_ptr = reinterpret_cast<GemmDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>() +
|
||||
Policy::template GetSmemSizeD<Problem>());
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
|
||||
return run(k_lds_ptr,
|
||||
v_lds_ptr,
|
||||
@@ -145,10 +131,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
do_lds_ptr1,
|
||||
q_lds_ptr0,
|
||||
q_lds_ptr1,
|
||||
lse_lds_ptr0,
|
||||
lse_lds_ptr1,
|
||||
d_lds_ptr0,
|
||||
d_lds_ptr1,
|
||||
lse_lds_ptr,
|
||||
d_lds_ptr,
|
||||
ds_lds_ptr,
|
||||
bias_lds_ptr,
|
||||
std::forward<Ts>(args)...);
|
||||
@@ -172,10 +156,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
OGradDataType* __restrict__ do_lds_ptr1,
|
||||
QDataType* __restrict__ q_lds_ptr0,
|
||||
QDataType* __restrict__ q_lds_ptr1,
|
||||
LSEDataType* __restrict__ lse_lds_ptr0,
|
||||
LSEDataType* __restrict__ lse_lds_ptr1,
|
||||
DDataType* __restrict__ d_lds_ptr0,
|
||||
DDataType* __restrict__ d_lds_ptr1,
|
||||
LSEDataType* __restrict__ lse_lds_ptr,
|
||||
DDataType* __restrict__ d_lds_ptr,
|
||||
GemmDataType* __restrict__ ds_lds_ptr,
|
||||
BiasDataType* __restrict__ bias_lds_ptr,
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
@@ -407,38 +389,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// LSE: HBM -> LDS ->Reg
|
||||
auto lse_dram_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem>());
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto lse_lds = make_tensor_view<address_space_enum::lds>(
|
||||
lse_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
|
||||
|
||||
auto lse_lds_read_window =
|
||||
make_tile_window(lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
|
||||
auto lse_lds_read_window = make_tile_window(
|
||||
lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// D: HBM ->Reg
|
||||
auto d_dram_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem>());
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_lds = make_tensor_view<address_space_enum::lds>(
|
||||
d_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
|
||||
auto d_lds_read_window =
|
||||
make_tile_window(d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
|
||||
auto d_lds_read_window = make_tile_window(
|
||||
d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// RandVal: HBM ->Reg
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
@@ -489,31 +471,27 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
|
||||
decltype(gemm_4.MakeCBlockTile()) dq_acc;
|
||||
|
||||
decltype(load_tile(lse_dram_window)) lse_block_tile;
|
||||
decltype(load_tile(d_dram_window)) d_block_tile;
|
||||
|
||||
index_t i_total_bodys = 0;
|
||||
auto main_body_impl = [&](auto is_prologue_,
|
||||
auto is_epilogue_,
|
||||
QDataType* const __restrict__ q_lds_ptr_curr,
|
||||
QDataType* const __restrict__ q_lds_ptr_next,
|
||||
OGradDataType* const __restrict__ do_lds_ptr_curr,
|
||||
OGradDataType* const __restrict__ do_lds_ptr_next,
|
||||
LSEDataType* const __restrict__ lse_lds_ptr_curr,
|
||||
LSEDataType* const __restrict__ lse_lds_ptr_next,
|
||||
DDataType* const __restrict__ d_lds_ptr_curr,
|
||||
DDataType* const __restrict__ d_lds_ptr_next
|
||||
|
||||
) mutable {
|
||||
OGradDataType* const __restrict__ do_lds_ptr_next) mutable {
|
||||
constexpr bool is_prologue = is_prologue_.value;
|
||||
constexpr bool is_epilogue = is_epilogue_.value;
|
||||
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
|
||||
constexpr bool is_main_body = is_prologue && is_epilogue;
|
||||
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
lse_lds_write_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_next);
|
||||
async_load_tile(lse_lds_write_window, lse_dram_window);
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
|
||||
d_lds_write_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_next);
|
||||
async_load_tile(d_lds_write_window, d_dram_window);
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
|
||||
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
|
||||
@@ -532,13 +510,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
|
||||
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
lse_lds_read_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_curr);
|
||||
lse = load_tile(lse_lds_read_window);
|
||||
d_lds_read_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_curr);
|
||||
d = load_tile(d_lds_read_window);
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -646,6 +617,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
store_tile(lse_lds_write_window, lse_block_tile);
|
||||
store_tile(d_lds_write_window, d_block_tile);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
@@ -700,12 +676,13 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
}
|
||||
s_waitcnt</*vmcnt=*/0>();
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
|
||||
q_reg_tensor = load_tile(q_lds_read_window);
|
||||
lse = load_tile(lse_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
@@ -743,6 +720,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
{
|
||||
do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
|
||||
do_reg_tensor = load_tile(do_lds_read_window);
|
||||
d = load_tile(d_lds_read_window);
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
|
||||
@@ -771,25 +749,17 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
};
|
||||
|
||||
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
|
||||
const bool is_even = (i_total_bodys % 2 == 0);
|
||||
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
|
||||
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
|
||||
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
|
||||
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
|
||||
const auto lse_lds_ptr_curr = is_even ? lse_lds_ptr1 : lse_lds_ptr0;
|
||||
const auto lse_lds_ptr_next = is_even ? lse_lds_ptr0 : lse_lds_ptr1;
|
||||
const auto d_lds_ptr_curr = is_even ? d_lds_ptr1 : d_lds_ptr0;
|
||||
const auto d_lds_ptr_next = is_even ? d_lds_ptr0 : d_lds_ptr1;
|
||||
const bool is_even = (i_total_bodys % 2 == 0);
|
||||
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
|
||||
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
|
||||
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
|
||||
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
|
||||
main_body_impl(is_prologue_,
|
||||
is_epilogue_,
|
||||
q_lds_ptr_curr,
|
||||
q_lds_ptr_next,
|
||||
do_lds_ptr_curr,
|
||||
do_lds_ptr_next,
|
||||
lse_lds_ptr_curr,
|
||||
lse_lds_ptr_next,
|
||||
d_lds_ptr_curr,
|
||||
d_lds_ptr_next);
|
||||
do_lds_ptr_next);
|
||||
i_total_bodys += 1;
|
||||
};
|
||||
|
||||
|
||||
@@ -363,38 +363,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// LSE: HBM -> LDS ->Reg
|
||||
auto lse_dram_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{0},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem>());
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{0},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto lse_lds = make_tensor_view<address_space_enum::lds>(
|
||||
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
|
||||
|
||||
auto lse_lds_read_window =
|
||||
make_tile_window(lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
|
||||
auto lse_lds_read_window = make_tile_window(
|
||||
lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// D: HBM ->Reg
|
||||
auto d_dram_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{0},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem>());
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{0},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_lds = make_tensor_view<address_space_enum::lds>(
|
||||
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
|
||||
auto d_lds_read_window =
|
||||
make_tile_window(d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
|
||||
auto d_lds_read_window = make_tile_window(
|
||||
d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// RandVal: HBM ->Reg
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), true>(
|
||||
|
||||
@@ -194,7 +194,13 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad()
|
||||
{
|
||||
return GetTransposedAlignmentX<typename Problem::OGradDataType>();
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
return total_pixels / GetAlignmentOGrad<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -352,30 +358,11 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
Problem::BlockFmhaShape::kVHeaddim>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
|
||||
constexpr index_t N0 = MWarp * NWarp;
|
||||
|
||||
constexpr index_t M1 = kMPerBlock;
|
||||
constexpr index_t M0 = get_warp_size() / M1;
|
||||
static_assert(M1 <= get_warp_size() && get_warp_size() % M1 == 0,
|
||||
"M1 must be a factor of warp size");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<N0, M0>,
|
||||
tuple<sequence<M1, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1>,
|
||||
sequence<1>>{});
|
||||
return BlockFmhaBwdPipelineDefaultPolicy::MakeLSEDDramTileDistribution<Problem,
|
||||
BlockGemm>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -806,10 +793,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
return lsed_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
@@ -998,16 +984,15 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE()
|
||||
{
|
||||
return static_cast<index_t>(max( //
|
||||
sizeof(int) * get_warp_size(),
|
||||
sizeof(typename Problem::LSEDataType) *
|
||||
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size()));
|
||||
return sizeof(typename Problem::LSEDataType) *
|
||||
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD()
|
||||
{
|
||||
return GetSmemSizeLSE<Problem>();
|
||||
return sizeof(typename Problem::DDataType) *
|
||||
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1054,9 +1039,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
|
||||
|
||||
constexpr index_t smem_size_stage0 = smem_size_k + smem_size_v;
|
||||
constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 +
|
||||
smem_size_lse * 2 + smem_size_d * 2 +
|
||||
max(smem_size_bias, smem_size_ds);
|
||||
constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 + smem_size_lse +
|
||||
smem_size_d + max(smem_size_bias, smem_size_ds);
|
||||
return max(smem_size_stage0, smem_size_stage1);
|
||||
}
|
||||
|
||||
@@ -1106,8 +1090,6 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
static constexpr index_t LSE_VMEM_READ = 1;
|
||||
static constexpr index_t D_VMEM_READ = 1;
|
||||
|
||||
static constexpr index_t DQ_VMEM_WRITE = kM0 * kQKHeaddim / kBlockSize; // atomic add
|
||||
|
||||
// LDS Read
|
||||
static constexpr index_t OGradT_LDS_READ =
|
||||
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
|
||||
@@ -1134,12 +1116,11 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t OGradT_LDS_WRITE =
|
||||
kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
|
||||
static constexpr index_t LSE_LDS_WRITE = 1;
|
||||
static constexpr index_t D_LDS_WRITE = 1;
|
||||
static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
|
||||
|
||||
public:
|
||||
static constexpr index_t TOTAL_VMEM_READ =
|
||||
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ + DQ_VMEM_WRITE;
|
||||
|
||||
CK_TILE_DEVICE static constexpr void SchedulerGemm0()
|
||||
{
|
||||
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
|
||||
@@ -1147,7 +1128,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
constexpr index_t VMEM_READ_INST =
|
||||
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
|
||||
constexpr index_t MFMA_INST = Gemm0MFMA;
|
||||
constexpr index_t LDS_READ_INST = OGradT_LDS_READ + LSE_LDS_READ + D_LDS_READ;
|
||||
constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
|
||||
|
||||
constexpr index_t lcm_inst = lcm(VMEM_READ_INST, MFMA_INST, LDS_READ_INST);
|
||||
static_for<0, lcm_inst, 1>{}([&](auto i) {
|
||||
@@ -1180,8 +1161,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
{
|
||||
// Mem: LSE/D LDS store, SGradT LDS store, SGrad, Q, LSE LDS load.
|
||||
// Comp: SGradT x QT
|
||||
constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
|
||||
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ;
|
||||
constexpr index_t LDS_WRITE_INST = LSE_LDS_WRITE + D_LDS_WRITE + SGradT_LDS_WRITE;
|
||||
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ;
|
||||
constexpr index_t MFMA_INST = Gemm3MFMA;
|
||||
|
||||
constexpr index_t lds_rw_inst = LDS_WRITE_INST + LDS_READ_INST;
|
||||
@@ -1204,7 +1185,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
{
|
||||
// Mem: SGrad, OGrad, D LDS load.
|
||||
// Comp: SGrad x KT
|
||||
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ;
|
||||
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ;
|
||||
constexpr index_t MFMA_INST = Gemm4MFMA;
|
||||
|
||||
constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST);
|
||||
|
||||
Reference in New Issue
Block a user