mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Create inner lambda with restrict parameters, add restrict to some parameters
This commit is contained in:
@@ -403,7 +403,7 @@ CK_TILE_DEVICE auto
|
||||
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window)
|
||||
NumCoord>& __restrict__ tile_window)
|
||||
{
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
|
||||
@@ -187,7 +187,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
// LDS allocation
|
||||
const auto smem_ptr_ =
|
||||
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
|
||||
|
||||
|
||||
auto restrict_body = [&](KDataType* __restrict__ k_lds_ptr,
|
||||
VDataType* __restrict__ v_lds_ptr,
|
||||
OGradDataType* __restrict__ do_lds_ptr0,
|
||||
OGradDataType* __restrict__ do_lds_ptr1,
|
||||
QDataType* __restrict__ q_lds_ptr0,
|
||||
QDataType* __restrict__ q_lds_ptr1,
|
||||
LSEDataType* __restrict__ lse_lds_ptr,
|
||||
DDataType* __restrict__ d_lds_ptr,
|
||||
GemmDataType* __restrict__ ds_lds_ptr,
|
||||
BiasDataType* __restrict__ bias_lds_ptr)
|
||||
{
|
||||
/*
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
|
||||
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
|
||||
@@ -217,7 +229,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType* __restrict__>(ds_lds_ptr);
|
||||
|
||||
*/
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window =
|
||||
@@ -753,6 +765,35 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
|
||||
};
|
||||
restrict_body(reinterpret_cast<KDataType*>(smem_ptr_), // k_lds_ptr
|
||||
reinterpret_cast<VDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>()), // v_lds_ptr
|
||||
reinterpret_cast<OGradDataType*>(smem_ptr_), // do_lds_ptr0
|
||||
reinterpret_cast<OGradDataType*>(smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>()), // do_lds_ptr1
|
||||
reinterpret_cast<QDataType*>(smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>()
|
||||
+ Policy::template GetSmemSizeOGrad<Problem>()), // q_lds_ptr0
|
||||
reinterpret_cast<QDataType*>(smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>()
|
||||
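+ Policy::template GetSmemSizeOGrad<Problem>()), // q_lds_ptr1 -- NOTE(review): same offset as q_lds_ptr0 above although both parameters are __restrict__; confirm whether a GetSmemSizeQ term is missing here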
|
||||
reinterpret_cast<LSEDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>()), // lse_lds_ptr
|
||||
reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>()), // d_lds_ptr
|
||||
reinterpret_cast<GemmDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>()), // ds_lds_ptr
|
||||
reinterpret_cast<BiasDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>())); // bias_lds_ptr
|
||||
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -856,15 +856,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
auto mainloop = [&](index_t cur_loop) {
|
||||
const bool is_even_loop = (cur_loop % 2 == 0);
|
||||
|
||||
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk1);
|
||||
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk0);
|
||||
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv0);
|
||||
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv1);
|
||||
auto innerloop = [&](KDataType* __restrict__ k_lds_write_ptr,
|
||||
KDataType* __restrict__ k_lds_read_ptr,
|
||||
KDataType* __restrict__ v_lds_write_ptr,
|
||||
KDataType* __restrict__ v_lds_read_ptr) {
|
||||
|
||||
// move V tile windows
|
||||
block_sync_lds<k_lds_insts>();
|
||||
@@ -1110,8 +1105,22 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
|
||||
});
|
||||
};
|
||||
|
||||
}; // innerloop
|
||||
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk1);
|
||||
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk0);
|
||||
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv0);
|
||||
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv1);
|
||||
innerloop(k_lds_write_ptr,
|
||||
k_lds_read_ptr,
|
||||
v_lds_write_ptr,
|
||||
v_lds_read_ptr);
|
||||
|
||||
}; // mainloop
|
||||
do
|
||||
{
|
||||
mainloop(i_total_loops);
|
||||
|
||||
Reference in New Issue
Block a user