mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Merge commit 'de61e554938265a5d17a1bba8c148457125e80cd' into develop
This commit is contained in:
@@ -1115,7 +1115,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
{i_n0, 0});
|
||||
if constexpr(!kUseQrQtrDorPipeline)
|
||||
{
|
||||
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
|
||||
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(smem_ptr,
|
||||
q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
@@ -1131,7 +1132,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
|
||||
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
|
||||
@@ -1139,7 +1139,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
FmhaPipeline{}(smem_ptr,
|
||||
q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
@@ -1160,7 +1161,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,7 +93,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
operator()(void* smem_ptr,
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
@@ -109,7 +110,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
|
||||
@@ -93,7 +93,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
operator()(void* smem_ptr,
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
@@ -109,7 +110,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
|
||||
@@ -90,6 +90,53 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
else
|
||||
return raw_lse;
|
||||
};
|
||||
template <typename... Ts>
|
||||
CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const
|
||||
{
|
||||
// LDS allocation
|
||||
// cast to char* to do pointer arithmetic
|
||||
const auto smem_ptr_ = reinterpret_cast<char*>(smem_ptr);
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr_);
|
||||
const auto v_lds_ptr =
|
||||
reinterpret_cast<VDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
|
||||
|
||||
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType*>(smem_ptr_);
|
||||
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>());
|
||||
const auto ds_lds_ptr = reinterpret_cast<GemmDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
|
||||
return run(k_lds_ptr,
|
||||
v_lds_ptr,
|
||||
do_lds_ptr0,
|
||||
do_lds_ptr1,
|
||||
q_lds_ptr0,
|
||||
q_lds_ptr1,
|
||||
lse_lds_ptr,
|
||||
d_lds_ptr,
|
||||
ds_lds_ptr,
|
||||
bias_lds_ptr,
|
||||
std::forward<Ts>(args)...);
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
@@ -102,7 +149,17 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_DEVICE auto operator()( //
|
||||
CK_TILE_DEVICE auto run( //
|
||||
KDataType* __restrict__ k_lds_ptr,
|
||||
VDataType* __restrict__ v_lds_ptr,
|
||||
OGradDataType* __restrict__ do_lds_ptr0,
|
||||
OGradDataType* __restrict__ do_lds_ptr1,
|
||||
QDataType* __restrict__ q_lds_ptr0,
|
||||
QDataType* __restrict__ q_lds_ptr1,
|
||||
LSEDataType* __restrict__ lse_lds_ptr,
|
||||
DDataType* __restrict__ d_lds_ptr,
|
||||
GemmDataType* __restrict__ ds_lds_ptr,
|
||||
BiasDataType* __restrict__ bias_lds_ptr,
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
@@ -119,7 +176,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
@@ -184,40 +240,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
}
|
||||
}
|
||||
|
||||
// LDS allocation
|
||||
const auto smem_ptr_ =
|
||||
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
|
||||
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
|
||||
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
|
||||
|
||||
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType* __restrict__>(smem_ptr_);
|
||||
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr0 = reinterpret_cast<QDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr1 = reinterpret_cast<QDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>());
|
||||
const auto ds_lds_ptr = reinterpret_cast<GemmDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType* __restrict__>(ds_lds_ptr);
|
||||
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window =
|
||||
@@ -453,13 +475,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
decltype(load_tile(d_dram_window)) d_block_tile;
|
||||
|
||||
index_t i_total_bodys = 0;
|
||||
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
|
||||
const bool is_even = (i_total_bodys % 2 == 0);
|
||||
QDataType* const __restrict__ q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
|
||||
QDataType* const __restrict__ q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
|
||||
OGradDataType* const __restrict__ do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
|
||||
OGradDataType* const __restrict__ do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
|
||||
|
||||
auto main_body_impl = [&](auto is_prologue_,
|
||||
auto is_epilogue_,
|
||||
QDataType* const __restrict__ q_lds_ptr_curr,
|
||||
QDataType* const __restrict__ q_lds_ptr_next,
|
||||
OGradDataType* const __restrict__ do_lds_ptr_curr,
|
||||
OGradDataType* const __restrict__ do_lds_ptr_next) mutable {
|
||||
constexpr bool is_prologue = is_prologue_.value;
|
||||
constexpr bool is_epilogue = is_epilogue_.value;
|
||||
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
|
||||
@@ -467,19 +488,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
|
||||
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
|
||||
async_load_tile(q_lds_write_window, q_dram_window);
|
||||
move_tile_window(q_dram_window, {kM0, 0});
|
||||
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
|
||||
do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
|
||||
async_load_tile(do_lds_write_window, do_dram_window);
|
||||
move_tile_window(do_dram_window, {kM0, 0});
|
||||
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
@@ -611,8 +632,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = p[i_j_idx] >= 0;
|
||||
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
|
||||
? (dp_acc[i_j_idx] - d[i_idx])
|
||||
: d[i_idx]);
|
||||
? (dp_acc[i_j_idx] - d[i_idx])
|
||||
: d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -725,6 +746,20 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
}
|
||||
move_tile_window(dq_dram_window, {kM0, 0});
|
||||
}
|
||||
};
|
||||
|
||||
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
|
||||
const bool is_even = (i_total_bodys % 2 == 0);
|
||||
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
|
||||
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
|
||||
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
|
||||
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
|
||||
main_body_impl(is_prologue_,
|
||||
is_epilogue_,
|
||||
q_lds_ptr_curr,
|
||||
q_lds_ptr_next,
|
||||
do_lds_ptr_curr,
|
||||
do_lds_ptr_next);
|
||||
i_total_bodys += 1;
|
||||
};
|
||||
|
||||
|
||||
@@ -93,6 +93,42 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
return raw_lse;
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const
|
||||
{
|
||||
// LDS allocation
|
||||
const auto smem_ptr_ =
|
||||
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
|
||||
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
|
||||
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
|
||||
|
||||
const auto do_lds_ptr = reinterpret_cast<OGradDataType*>(smem_ptr_);
|
||||
const auto q_lds_ptr = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>());
|
||||
|
||||
const auto ds_lds_ptr =
|
||||
reinterpret_cast<GemmDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeV<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
|
||||
return run(k_lds_ptr,
|
||||
v_lds_ptr,
|
||||
do_lds_ptr,
|
||||
q_lds_ptr,
|
||||
lse_lds_ptr,
|
||||
d_lds_ptr,
|
||||
ds_lds_ptr,
|
||||
bias_lds_ptr,
|
||||
std::forward<Ts>(args)...);
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
@@ -109,7 +145,15 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
typename KGradEpilogue,
|
||||
typename VGradEpilogue,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_DEVICE auto operator()( //
|
||||
CK_TILE_DEVICE auto run( //
|
||||
KDataType* __restrict__ k_lds_ptr,
|
||||
VDataType* __restrict__ v_lds_ptr,
|
||||
OGradDataType* __restrict__ do_lds_ptr,
|
||||
QDataType* __restrict__ q_lds_ptr,
|
||||
LSEDataType* __restrict__ lse_lds_ptr,
|
||||
DDataType* __restrict__ d_lds_ptr,
|
||||
GemmDataType* __restrict__ ds_lds_ptr,
|
||||
BiasDataType* __restrict__ bias_lds_ptr,
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
@@ -131,7 +175,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
@@ -181,29 +224,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
{seqlen_kv_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
// LDS allocation
|
||||
const auto smem_ptr_ =
|
||||
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
|
||||
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
|
||||
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
|
||||
|
||||
const auto do_lds_ptr = reinterpret_cast<OGradDataType*>(smem_ptr_);
|
||||
const auto q_lds_ptr = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>());
|
||||
|
||||
const auto ds_lds_ptr =
|
||||
reinterpret_cast<GemmDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeV<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
|
||||
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window =
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -638,11 +638,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
|
||||
operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& __restrict__ bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& __restrict__ lse_acc_dram_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
@@ -854,18 +854,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto mainloop = [&](index_t cur_loop) {
|
||||
const bool is_even_loop = (cur_loop % 2 == 0);
|
||||
|
||||
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk1);
|
||||
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk0);
|
||||
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv0);
|
||||
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv1);
|
||||
|
||||
auto mainloop = [&](KDataType* __restrict__ k_lds_write_ptr,
|
||||
KDataType* __restrict__ k_lds_read_ptr,
|
||||
KDataType* __restrict__ v_lds_write_ptr,
|
||||
KDataType* __restrict__ v_lds_read_ptr) {
|
||||
// move V tile windows
|
||||
block_sync_lds<k_lds_insts>();
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
@@ -1110,11 +1102,20 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
|
||||
});
|
||||
};
|
||||
}; // mainloop
|
||||
|
||||
do
|
||||
{
|
||||
mainloop(i_total_loops);
|
||||
bool is_even_loop = i_total_loops % 2 == 0;
|
||||
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk1);
|
||||
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk0);
|
||||
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv0);
|
||||
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv1);
|
||||
mainloop(k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
|
||||
i_total_loops++;
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user