Create inner lambda with restrict parameters, add restrict to some parameters

This commit is contained in:
Kevin Choi
2025-08-14 07:06:51 +00:00
parent 3bc45ecbc7
commit 3340408537
3 changed files with 63 additions and 13 deletions

View File

@@ -403,7 +403,7 @@ CK_TILE_DEVICE auto
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
NumCoord>& __restrict__ tile_window)
{
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,

View File

@@ -187,7 +187,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
// LDS allocation
const auto smem_ptr_ =
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
auto restrict_body = [&](KDataType* __restrict__ k_lds_ptr,
VDataType* __restrict__ v_lds_ptr,
OGradDataType* __restrict__ do_lds_ptr0,
OGradDataType* __restrict__ do_lds_ptr1,
QDataType* __restrict__ q_lds_ptr0,
QDataType* __restrict__ q_lds_ptr1,
LSEDataType* __restrict__ lse_lds_ptr,
DDataType* __restrict__ d_lds_ptr,
GemmDataType* __restrict__ ds_lds_ptr,
BiasDataType* __restrict__ bias_lds_ptr)
{
/*
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
@@ -217,7 +229,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
const auto bias_lds_ptr = reinterpret_cast<BiasDataType* __restrict__>(ds_lds_ptr);
*/
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =
@@ -753,6 +765,35 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
};
restrict_body(reinterpret_cast<KDataType*>(smem_ptr_), // k_lds_ptr
reinterpret_cast<VDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>()), // v_lds_ptr
reinterpret_cast<OGradDataType*>(smem_ptr_), // do_lds_ptr0
reinterpret_cast<OGradDataType*>(smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>()), // do_lds_ptr1
reinterpret_cast<QDataType*>(smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>()
+ Policy::template GetSmemSizeOGrad<Problem>()), // q_lds_ptr0
reinterpret_cast<QDataType*>(smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>()
+ Policy::template GetSmemSizeOGrad<Problem>()), // q_lds_ptr1
reinterpret_cast<LSEDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>()), // lse_lds_ptr
reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>()), // d_lds_ptr
reinterpret_cast<GemmDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>()), // ds_ltr_ptr
reinterpret_cast<BiasDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>())); // bias_ltr_ptr
return make_tuple(dk_acc, dv_acc);
}
};

View File

@@ -856,15 +856,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
auto mainloop = [&](index_t cur_loop) {
const bool is_even_loop = (cur_loop % 2 == 0);
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
: static_cast<KDataType* __restrict__>(smem_ptrk1);
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
: static_cast<KDataType* __restrict__>(smem_ptrk0);
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
: static_cast<VDataType* __restrict__>(smem_ptrv0);
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
: static_cast<VDataType* __restrict__>(smem_ptrv1);
auto innerloop = [&](KDataType* __restrict__ k_lds_write_ptr,
KDataType* __restrict__ k_lds_read_ptr,
KDataType* __restrict__ v_lds_write_ptr,
KDataType* __restrict__ v_lds_read_ptr) {
// move V tile windows
block_sync_lds<k_lds_insts>();
@@ -1110,8 +1105,22 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
});
};
}; // innerloop
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
: static_cast<KDataType* __restrict__>(smem_ptrk1);
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
: static_cast<KDataType* __restrict__>(smem_ptrk0);
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
: static_cast<VDataType* __restrict__>(smem_ptrv0);
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
: static_cast<VDataType* __restrict__>(smem_ptrv1);
innerloop(k_lds_write_ptr,
k_lds_read_ptr,
v_lds_write_ptr,
v_lds_read_ptr);
}; // mainloop
do
{
mainloop(i_total_loops);