mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-25 09:37:42 +00:00
update bwd
This commit is contained in:
@@ -187,18 +187,17 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
// LDS allocation
|
||||
const auto smem_ptr_ =
|
||||
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
|
||||
|
||||
auto restrict_body = [&](KDataType* __restrict__ k_lds_ptr,
|
||||
VDataType* __restrict__ v_lds_ptr,
|
||||
OGradDataType* __restrict__ do_lds_ptr0,
|
||||
OGradDataType* __restrict__ do_lds_ptr1,
|
||||
QDataType* __restrict__ q_lds_ptr0,
|
||||
QDataType* __restrict__ q_lds_ptr1,
|
||||
LSEDataType* __restrict__ lse_lds_ptr,
|
||||
DDataType* __restrict__ d_lds_ptr,
|
||||
GemmDataType* __restrict__ ds_lds_ptr,
|
||||
BiasDataType* __restrict__ bias_lds_ptr)
|
||||
{
|
||||
auto restrict_body = [&](KDataType* __restrict__ k_lds_ptr,
|
||||
VDataType* __restrict__ v_lds_ptr,
|
||||
OGradDataType* __restrict__ do_lds_ptr0,
|
||||
OGradDataType* __restrict__ do_lds_ptr1,
|
||||
QDataType* __restrict__ q_lds_ptr0,
|
||||
QDataType* __restrict__ q_lds_ptr1,
|
||||
LSEDataType* __restrict__ lse_lds_ptr,
|
||||
DDataType* __restrict__ d_lds_ptr,
|
||||
GemmDataType* __restrict__ ds_lds_ptr,
|
||||
BiasDataType* __restrict__ bias_lds_ptr)
|
||||
{
|
||||
/*
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
|
||||
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
|
||||
@@ -468,13 +467,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
decltype(load_tile(d_dram_window)) d_block_tile;
|
||||
|
||||
index_t i_total_bodys = 0;
|
||||
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
|
||||
const bool is_even = (i_total_bodys % 2 == 0);
|
||||
QDataType* const __restrict__ q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
|
||||
QDataType* const __restrict__ q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
|
||||
OGradDataType* const __restrict__ do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
|
||||
OGradDataType* const __restrict__ do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
|
||||
|
||||
auto main_body_impl = [&](auto is_prologue_,
|
||||
auto is_epilogue_,
|
||||
QDataType* const __restrict__ q_lds_ptr_curr,
|
||||
QDataType* const __restrict__ q_lds_ptr_next,
|
||||
OGradDataType* const __restrict__ do_lds_ptr_curr,
|
||||
OGradDataType* const __restrict__ do_lds_ptr_next) mutable {
|
||||
constexpr bool is_prologue = is_prologue_.value;
|
||||
constexpr bool is_epilogue = is_epilogue_.value;
|
||||
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
|
||||
@@ -482,19 +480,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
|
||||
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
|
||||
async_load_tile(q_lds_write_window, q_dram_window);
|
||||
move_tile_window(q_dram_window, {kM0, 0});
|
||||
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
|
||||
do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
|
||||
async_load_tile(do_lds_write_window, do_dram_window);
|
||||
move_tile_window(do_dram_window, {kM0, 0});
|
||||
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
@@ -737,6 +735,20 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
}
|
||||
move_tile_window(dq_dram_window, {kM0, 0});
|
||||
}
|
||||
};
|
||||
|
||||
// main_body: thin wrapper around main_body_impl.
// Picks the "current" and "next" Q / dO buffer pointers from the parity of
// i_total_bodys, forwards them (together with the prologue/epilogue tags) to
// main_body_impl, and then bumps i_total_bodys so the next call swaps roles.
//   even count: curr = *_lds_ptr1, next = *_lds_ptr0
//   odd  count: curr = *_lds_ptr0, next = *_lds_ptr1
// NOTE(review): this looks like a two-buffer ping-pong between the *_ptr0 and
// *_ptr1 regions -- confirm against main_body_impl's read/write usage.
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
|
||||
// Parity of the number of bodies executed so far decides which buffer is which.
const bool is_even = (i_total_bodys % 2 == 0);
|
||||
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
|
||||
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
|
||||
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
|
||||
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
|
||||
// The selected pointers are handed over as arguments rather than captured.
main_body_impl(is_prologue_,
|
||||
is_epilogue_,
|
||||
q_lds_ptr_curr,
|
||||
q_lds_ptr_next,
|
||||
do_lds_ptr_curr,
|
||||
do_lds_ptr_next);
|
||||
// Advance the counter so the following invocation flips curr/next.
i_total_bodys += 1;
|
||||
};
|
||||
|
||||
@@ -766,29 +778,29 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
}
|
||||
|
||||
};
|
||||
restrict_body(reinterpret_cast<KDataType*>(smem_ptr_), // k_lds_ptr
|
||||
reinterpret_cast<VDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>()), // v_lds_ptr
|
||||
reinterpret_cast<OGradDataType*>(smem_ptr_), // do_lds_ptr0
|
||||
reinterpret_cast<OGradDataType*>(smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>()), // do_lds_ptr1
|
||||
reinterpret_cast<QDataType*>(smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>()
|
||||
+ Policy::template GetSmemSizeOGrad<Problem>()), // q_lds_ptr0
|
||||
reinterpret_cast<QDataType*>(smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>()
|
||||
+ Policy::template GetSmemSizeOGrad<Problem>()), // q_lds_ptr1
|
||||
reinterpret_cast<LSEDataType*>(
|
||||
restrict_body(reinterpret_cast<KDataType*>(smem_ptr_), // k_lds_ptr
|
||||
reinterpret_cast<VDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>()), // v_lds_ptr
|
||||
reinterpret_cast<OGradDataType*>(smem_ptr_), // do_lds_ptr0
|
||||
reinterpret_cast<OGradDataType*>(smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>()), // do_lds_ptr1
|
||||
reinterpret_cast<QDataType*>(smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>()
|
||||
+ Policy::template GetSmemSizeOGrad<Problem>()), // q_lds_ptr0
|
||||
reinterpret_cast<QDataType*>(smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>()
|
||||
+ Policy::template GetSmemSizeOGrad<Problem>() + Policy::template GetSmemSizeQ<Problem>()), // q_lds_ptr1
|
||||
reinterpret_cast<LSEDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>()), // lse_lds_ptr
|
||||
reinterpret_cast<DDataType*>(
|
||||
reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>()), // d_lds_ptr
|
||||
reinterpret_cast<GemmDataType*>(
|
||||
reinterpret_cast<GemmDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>()), // ds_lds_ptr
|
||||
reinterpret_cast<BiasDataType*>(
|
||||
reinterpret_cast<BiasDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
|
||||
Reference in New Issue
Block a user