mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
This reverts commitf2e6edde3b. [ROCm/composable_kernel commit:038ea82315]
This commit is contained in:
@@ -125,8 +125,7 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
|
||||
{F_dvpad},
|
||||
{F_deterministic},
|
||||
{F_trload},
|
||||
{F_maxq},
|
||||
{F_bn0}>;
|
||||
{F_maxq}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -219,10 +218,10 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0)
|
||||
|
||||
FMHA_BWD_API_INNER_DISPATCH="""
|
||||
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_convert_dq_bn0}>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
|
||||
return r;
|
||||
}}
|
||||
@@ -387,7 +386,6 @@ def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]
|
||||
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't':
|
||||
return [
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
|
||||
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
|
||||
]
|
||||
@@ -521,8 +519,7 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
|
||||
{F_mode},
|
||||
{F_spad},
|
||||
{F_dpad},
|
||||
{F_deterministic},
|
||||
{F_bn0}>;
|
||||
{F_deterministic}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -659,17 +656,6 @@ class FmhaBwdApiTrait:
|
||||
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
|
||||
else : return f'a.hdim_v % {self.bhdv} == 0'
|
||||
|
||||
@property
|
||||
def extra_cond(self) -> str:
|
||||
if self.tr_load == 't' and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
|
||||
return "&& (a.seqlen_k <= 256)"
|
||||
else:
|
||||
return ""
|
||||
|
||||
@property
|
||||
def convert_dq_bn0(self) -> int:
|
||||
return self.tile.F_bn0 if self.deterministic == 't' else 0
|
||||
|
||||
@property
|
||||
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
@@ -694,7 +680,7 @@ class FmhaBwdApiTrait:
|
||||
return 2
|
||||
|
||||
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
|
||||
F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
|
||||
F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
|
||||
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
|
||||
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
|
||||
|
||||
@@ -722,8 +708,7 @@ class FmhaBwdApiPool:
|
||||
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
|
||||
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond,
|
||||
F_convert_dq_bn0=trait.convert_dq_bn0)
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled])
|
||||
i += 1
|
||||
return inners
|
||||
|
||||
@@ -806,9 +791,6 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
continue
|
||||
if tr_load == "t" and (dpad == "t" or dvpad == "t"):
|
||||
continue # tr_load cannot work with dpad or dvpad
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load)
|
||||
|
||||
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
|
||||
@@ -817,6 +799,9 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
continue
|
||||
if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
|
||||
# Flash attention integration
|
||||
if receipt == 2:
|
||||
|
||||
@@ -803,6 +803,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dq_buf.SetZero();
|
||||
dbias_buf.SetZero();
|
||||
dq_acc_buf.SetZero();
|
||||
|
||||
|
||||
@@ -372,8 +372,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadDv_,
|
||||
bool kIsDeterministic_,
|
||||
bool kUseTrLoad_,
|
||||
ck_tile::index_t MaxSeqLenQ_,
|
||||
ck_tile::index_t kN0>
|
||||
ck_tile::index_t MaxSeqLenQ_>
|
||||
struct fmha_bwd_dq_dk_dv_traits_
|
||||
{
|
||||
};
|
||||
@@ -413,10 +412,15 @@ template <ck_tile::index_t HDim_,
|
||||
bool kIsGroupMode_,
|
||||
bool kPadS_,
|
||||
bool kPadD_,
|
||||
bool kIsDeterministic_,
|
||||
ck_tile::index_t kN0>
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_convert_dq_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
|
||||
@@ -103,41 +103,27 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType*>(smem_ptr_);
|
||||
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
|
||||
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
|
||||
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto lse_lds_ptr0 = reinterpret_cast<LSEDataType*>(
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto lse_lds_ptr1 = reinterpret_cast<LSEDataType*>(
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>());
|
||||
const auto d_lds_ptr0 = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>());
|
||||
const auto d_lds_ptr1 = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
|
||||
const auto ds_lds_ptr = reinterpret_cast<GemmDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>() +
|
||||
Policy::template GetSmemSizeD<Problem>());
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
|
||||
return run(k_lds_ptr,
|
||||
v_lds_ptr,
|
||||
@@ -145,10 +131,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
do_lds_ptr1,
|
||||
q_lds_ptr0,
|
||||
q_lds_ptr1,
|
||||
lse_lds_ptr0,
|
||||
lse_lds_ptr1,
|
||||
d_lds_ptr0,
|
||||
d_lds_ptr1,
|
||||
lse_lds_ptr,
|
||||
d_lds_ptr,
|
||||
ds_lds_ptr,
|
||||
bias_lds_ptr,
|
||||
std::forward<Ts>(args)...);
|
||||
@@ -172,10 +156,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
OGradDataType* __restrict__ do_lds_ptr1,
|
||||
QDataType* __restrict__ q_lds_ptr0,
|
||||
QDataType* __restrict__ q_lds_ptr1,
|
||||
LSEDataType* __restrict__ lse_lds_ptr0,
|
||||
LSEDataType* __restrict__ lse_lds_ptr1,
|
||||
DDataType* __restrict__ d_lds_ptr0,
|
||||
DDataType* __restrict__ d_lds_ptr1,
|
||||
LSEDataType* __restrict__ lse_lds_ptr,
|
||||
DDataType* __restrict__ d_lds_ptr,
|
||||
GemmDataType* __restrict__ ds_lds_ptr,
|
||||
BiasDataType* __restrict__ bias_lds_ptr,
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
@@ -407,38 +389,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// LSE: HBM -> LDS ->Reg
|
||||
auto lse_dram_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem>());
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto lse_lds = make_tensor_view<address_space_enum::lds>(
|
||||
lse_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
|
||||
|
||||
auto lse_lds_read_window =
|
||||
make_tile_window(lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
|
||||
auto lse_lds_read_window = make_tile_window(
|
||||
lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// D: HBM ->Reg
|
||||
auto d_dram_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem>());
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_lds = make_tensor_view<address_space_enum::lds>(
|
||||
d_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
|
||||
auto d_lds_read_window =
|
||||
make_tile_window(d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
|
||||
auto d_lds_read_window = make_tile_window(
|
||||
d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// RandVal: HBM ->Reg
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
@@ -489,31 +471,27 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
|
||||
decltype(gemm_4.MakeCBlockTile()) dq_acc;
|
||||
|
||||
decltype(load_tile(lse_dram_window)) lse_block_tile;
|
||||
decltype(load_tile(d_dram_window)) d_block_tile;
|
||||
|
||||
index_t i_total_bodys = 0;
|
||||
auto main_body_impl = [&](auto is_prologue_,
|
||||
auto is_epilogue_,
|
||||
QDataType* const __restrict__ q_lds_ptr_curr,
|
||||
QDataType* const __restrict__ q_lds_ptr_next,
|
||||
OGradDataType* const __restrict__ do_lds_ptr_curr,
|
||||
OGradDataType* const __restrict__ do_lds_ptr_next,
|
||||
LSEDataType* const __restrict__ lse_lds_ptr_curr,
|
||||
LSEDataType* const __restrict__ lse_lds_ptr_next,
|
||||
DDataType* const __restrict__ d_lds_ptr_curr,
|
||||
DDataType* const __restrict__ d_lds_ptr_next
|
||||
|
||||
) mutable {
|
||||
OGradDataType* const __restrict__ do_lds_ptr_next) mutable {
|
||||
constexpr bool is_prologue = is_prologue_.value;
|
||||
constexpr bool is_epilogue = is_epilogue_.value;
|
||||
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
|
||||
constexpr bool is_main_body = is_prologue && is_epilogue;
|
||||
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
lse_lds_write_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_next);
|
||||
async_load_tile(lse_lds_write_window, lse_dram_window);
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
|
||||
d_lds_write_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_next);
|
||||
async_load_tile(d_lds_write_window, d_dram_window);
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
|
||||
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
|
||||
@@ -532,13 +510,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
|
||||
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
lse_lds_read_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_curr);
|
||||
lse = load_tile(lse_lds_read_window);
|
||||
d_lds_read_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_curr);
|
||||
d = load_tile(d_lds_read_window);
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -646,6 +617,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
store_tile(lse_lds_write_window, lse_block_tile);
|
||||
store_tile(d_lds_write_window, d_block_tile);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
@@ -700,12 +676,13 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
}
|
||||
s_waitcnt</*vmcnt=*/0>();
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
|
||||
q_reg_tensor = load_tile(q_lds_read_window);
|
||||
lse = load_tile(lse_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
@@ -743,6 +720,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
{
|
||||
do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
|
||||
do_reg_tensor = load_tile(do_lds_read_window);
|
||||
d = load_tile(d_lds_read_window);
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
|
||||
@@ -771,25 +749,17 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
};
|
||||
|
||||
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
|
||||
const bool is_even = (i_total_bodys % 2 == 0);
|
||||
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
|
||||
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
|
||||
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
|
||||
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
|
||||
const auto lse_lds_ptr_curr = is_even ? lse_lds_ptr1 : lse_lds_ptr0;
|
||||
const auto lse_lds_ptr_next = is_even ? lse_lds_ptr0 : lse_lds_ptr1;
|
||||
const auto d_lds_ptr_curr = is_even ? d_lds_ptr1 : d_lds_ptr0;
|
||||
const auto d_lds_ptr_next = is_even ? d_lds_ptr0 : d_lds_ptr1;
|
||||
const bool is_even = (i_total_bodys % 2 == 0);
|
||||
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
|
||||
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
|
||||
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
|
||||
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
|
||||
main_body_impl(is_prologue_,
|
||||
is_epilogue_,
|
||||
q_lds_ptr_curr,
|
||||
q_lds_ptr_next,
|
||||
do_lds_ptr_curr,
|
||||
do_lds_ptr_next,
|
||||
lse_lds_ptr_curr,
|
||||
lse_lds_ptr_next,
|
||||
d_lds_ptr_curr,
|
||||
d_lds_ptr_next);
|
||||
do_lds_ptr_next);
|
||||
i_total_bodys += 1;
|
||||
};
|
||||
|
||||
|
||||
@@ -363,38 +363,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// LSE: HBM -> LDS ->Reg
|
||||
auto lse_dram_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{0},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem>());
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{0},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto lse_lds = make_tensor_view<address_space_enum::lds>(
|
||||
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
|
||||
|
||||
auto lse_lds_read_window =
|
||||
make_tile_window(lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
|
||||
auto lse_lds_read_window = make_tile_window(
|
||||
lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// D: HBM ->Reg
|
||||
auto d_dram_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{0},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem>());
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{0},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_lds = make_tensor_view<address_space_enum::lds>(
|
||||
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
|
||||
auto d_lds_read_window =
|
||||
make_tile_window(d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
|
||||
auto d_lds_read_window = make_tile_window(
|
||||
d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// RandVal: HBM ->Reg
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), true>(
|
||||
|
||||
@@ -194,7 +194,13 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad()
|
||||
{
|
||||
return GetTransposedAlignmentX<typename Problem::OGradDataType>();
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
return total_pixels / GetAlignmentOGrad<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -352,30 +358,11 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
Problem::BlockFmhaShape::kVHeaddim>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
|
||||
constexpr index_t N0 = MWarp * NWarp;
|
||||
|
||||
constexpr index_t M1 = kMPerBlock;
|
||||
constexpr index_t M0 = get_warp_size() / M1;
|
||||
static_assert(M1 <= get_warp_size() && get_warp_size() % M1 == 0,
|
||||
"M1 must be a factor of warp size");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<N0, M0>,
|
||||
tuple<sequence<M1, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1>,
|
||||
sequence<1>>{});
|
||||
return BlockFmhaBwdPipelineDefaultPolicy::MakeLSEDDramTileDistribution<Problem,
|
||||
BlockGemm>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -806,10 +793,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
return lsed_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
@@ -998,16 +984,15 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE()
|
||||
{
|
||||
return static_cast<index_t>(max( //
|
||||
sizeof(int) * get_warp_size(),
|
||||
sizeof(typename Problem::LSEDataType) *
|
||||
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size()));
|
||||
return sizeof(typename Problem::LSEDataType) *
|
||||
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD()
|
||||
{
|
||||
return GetSmemSizeLSE<Problem>();
|
||||
return sizeof(typename Problem::DDataType) *
|
||||
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1054,9 +1039,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
|
||||
|
||||
constexpr index_t smem_size_stage0 = smem_size_k + smem_size_v;
|
||||
constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 +
|
||||
smem_size_lse * 2 + smem_size_d * 2 +
|
||||
max(smem_size_bias, smem_size_ds);
|
||||
constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 + smem_size_lse +
|
||||
smem_size_d + max(smem_size_bias, smem_size_ds);
|
||||
return max(smem_size_stage0, smem_size_stage1);
|
||||
}
|
||||
|
||||
@@ -1106,8 +1090,6 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
static constexpr index_t LSE_VMEM_READ = 1;
|
||||
static constexpr index_t D_VMEM_READ = 1;
|
||||
|
||||
static constexpr index_t DQ_VMEM_WRITE = kM0 * kQKHeaddim / kBlockSize; // atomic add
|
||||
|
||||
// LDS Read
|
||||
static constexpr index_t OGradT_LDS_READ =
|
||||
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
|
||||
@@ -1134,12 +1116,11 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t OGradT_LDS_WRITE =
|
||||
kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
|
||||
static constexpr index_t LSE_LDS_WRITE = 1;
|
||||
static constexpr index_t D_LDS_WRITE = 1;
|
||||
static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
|
||||
|
||||
public:
|
||||
static constexpr index_t TOTAL_VMEM_READ =
|
||||
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ + DQ_VMEM_WRITE;
|
||||
|
||||
CK_TILE_DEVICE static constexpr void SchedulerGemm0()
|
||||
{
|
||||
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
|
||||
@@ -1147,7 +1128,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
constexpr index_t VMEM_READ_INST =
|
||||
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
|
||||
constexpr index_t MFMA_INST = Gemm0MFMA;
|
||||
constexpr index_t LDS_READ_INST = OGradT_LDS_READ + LSE_LDS_READ + D_LDS_READ;
|
||||
constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
|
||||
|
||||
constexpr index_t lcm_inst = lcm(VMEM_READ_INST, MFMA_INST, LDS_READ_INST);
|
||||
static_for<0, lcm_inst, 1>{}([&](auto i) {
|
||||
@@ -1180,8 +1161,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
{
|
||||
// Mem: LSE/D LDS store, SGradT LDS store, SGrad, Q, LSE LDS load.
|
||||
// Comp: SGradT x QT
|
||||
constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
|
||||
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ;
|
||||
constexpr index_t LDS_WRITE_INST = LSE_LDS_WRITE + D_LDS_WRITE + SGradT_LDS_WRITE;
|
||||
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ;
|
||||
constexpr index_t MFMA_INST = Gemm3MFMA;
|
||||
|
||||
constexpr index_t lds_rw_inst = LDS_WRITE_INST + LDS_READ_INST;
|
||||
@@ -1204,7 +1185,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
{
|
||||
// Mem: SGrad, OGrad, D LDS load.
|
||||
// Comp: SGrad x KT
|
||||
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ;
|
||||
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ;
|
||||
constexpr index_t MFMA_INST = Gemm4MFMA;
|
||||
|
||||
constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST);
|
||||
|
||||
Reference in New Issue
Block a user