[CK_TILE] FMHA BWD Enable Tile 16x192 (#2741)

* 16x192

* Use buffer_load_lds for lse/d

* Dispatch & cleanup

* Avoid zeroing dq & fix

* fix

[ROCm/composable_kernel commit: ead4447b20]
This commit is contained in:
Yi DING
2025-08-28 18:54:18 +08:00
committed by GitHub
parent 146a41f0d8
commit 8b537fb883
6 changed files with 173 additions and 114 deletions

View File

@@ -125,7 +125,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dvpad},
{F_deterministic},
{F_trload},
{F_maxq}>;
{F_maxq},
{F_bn0}>;
#include <iostream>
@@ -218,10 +219,10 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0)
FMHA_BWD_API_INNER_DISPATCH="""
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_convert_dq_bn0}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
return r;
}}
@@ -386,6 +387,7 @@ def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't':
return [
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
]
@@ -519,7 +521,8 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_mode},
{F_spad},
{F_dpad},
{F_deterministic}>;
{F_deterministic},
{F_bn0}>;
#include <iostream>
@@ -656,6 +659,17 @@ class FmhaBwdApiTrait:
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
else : return f'a.hdim_v % {self.bhdv} == 0'
# Extra runtime guard for the generated dispatch condition: the returned text is
# substituted as F_cond_extra into FMHA_BWD_API_INNER_DISPATCH, directly after the
# is_deterministic check. Returns either a C++ fragment that starts with "&&" or an
# empty string when no additional restriction applies.
@property
def extra_cond(self) -> str:
# Only tr_load traits whose tile has max_seq_q == 0 and F_bn0 == 128 get the
# seqlen_k <= 256 restriction.
# NOTE(review): presumably larger seqlen_k is meant to fall through to a later tile
# in the dispatch chain (e.g. the 16x192 entry) — confirm against the tile ordering.
if self.tr_load == 't' and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
return "&& (a.seqlen_k <= 256)"
else:
return ""
# bn0 value handed to the convert_dq side of the codegen: it is passed as F_bn0 when
# building FmhaBwdConvertQGradKernel and as F_convert_dq_bn0 in the dispatch template.
# Evaluates to the tile's F_bn0 when deterministic == 't', and to 0 otherwise.
# NOTE(review): presumably 0 is a "don't care" value so non-deterministic traits with
# different tiles share one convert_dq instantiation — confirm with the kernel naming.
@property
def convert_dq_bn0(self) -> int:
return self.tile.F_bn0 if self.deterministic == 't' else 0
@property
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
@@ -680,7 +694,7 @@ class FmhaBwdApiTrait:
return 2
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
@@ -708,7 +722,8 @@ class FmhaBwdApiPool:
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled])
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond,
F_convert_dq_bn0=trait.convert_dq_bn0)
i += 1
return inners
@@ -791,6 +806,9 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
continue
if tr_load == "t" and (dpad == "t" or dvpad == "t"):
continue # tr_load cannot work with dpad or dvpad
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load)
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
@@ -799,9 +817,6 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
continue
if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# Flash attention integration
if receipt == 2:

View File

@@ -803,7 +803,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_buf.ToDevice(o_host.data());
lse_buf.ToDevice(lse_host.data());
dq_buf.SetZero();
dbias_buf.SetZero();
dq_acc_buf.SetZero();

View File

@@ -372,7 +372,8 @@ template <ck_tile::index_t HDim_,
bool kPadDv_,
bool kIsDeterministic_,
bool kUseTrLoad_,
ck_tile::index_t MaxSeqLenQ_>
ck_tile::index_t MaxSeqLenQ_,
ck_tile::index_t kN0>
struct fmha_bwd_dq_dk_dv_traits_
{
};
@@ -412,15 +413,10 @@ template <ck_tile::index_t HDim_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_>
bool kIsDeterministic_,
ck_tile::index_t kN0>
struct fmha_bwd_convert_dq_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>

View File

@@ -103,27 +103,41 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType*>(smem_ptr_);
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>());
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>(
const auto lse_lds_ptr0 = reinterpret_cast<LSEDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
const auto lse_lds_ptr1 = reinterpret_cast<LSEDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>());
const auto d_lds_ptr0 = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeLSE<Problem>());
const auto d_lds_ptr1 = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
const auto ds_lds_ptr = reinterpret_cast<GemmDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>() +
Policy::template GetSmemSizeD<Problem>());
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
return run(k_lds_ptr,
v_lds_ptr,
@@ -131,8 +145,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
do_lds_ptr1,
q_lds_ptr0,
q_lds_ptr1,
lse_lds_ptr,
d_lds_ptr,
lse_lds_ptr0,
lse_lds_ptr1,
d_lds_ptr0,
d_lds_ptr1,
ds_lds_ptr,
bias_lds_ptr,
std::forward<Ts>(args)...);
@@ -156,8 +172,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
OGradDataType* __restrict__ do_lds_ptr1,
QDataType* __restrict__ q_lds_ptr0,
QDataType* __restrict__ q_lds_ptr1,
LSEDataType* __restrict__ lse_lds_ptr,
DDataType* __restrict__ d_lds_ptr,
LSEDataType* __restrict__ lse_lds_ptr0,
LSEDataType* __restrict__ lse_lds_ptr1,
DDataType* __restrict__ d_lds_ptr0,
DDataType* __restrict__ d_lds_ptr1,
GemmDataType* __restrict__ ds_lds_ptr,
BiasDataType* __restrict__ bias_lds_ptr,
const QDramBlockWindowTmp& q_dram_block_window_tmp,
@@ -389,38 +407,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window = make_tile_window(
lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto lse_dram_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
lse_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto lse_lds_read_window =
make_tile_window(lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
// D: HBM ->Reg
auto d_dram_window = make_tile_window(
d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto d_lds = make_tensor_view<address_space_enum::lds>(
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
d_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto d_lds_read_window =
make_tile_window(d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
@@ -471,27 +489,31 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
decltype(gemm_4.MakeCBlockTile()) dq_acc;
decltype(load_tile(lse_dram_window)) lse_block_tile;
decltype(load_tile(d_dram_window)) d_block_tile;
index_t i_total_bodys = 0;
auto main_body_impl = [&](auto is_prologue_,
auto is_epilogue_,
QDataType* const __restrict__ q_lds_ptr_curr,
QDataType* const __restrict__ q_lds_ptr_next,
OGradDataType* const __restrict__ do_lds_ptr_curr,
OGradDataType* const __restrict__ do_lds_ptr_next) mutable {
OGradDataType* const __restrict__ do_lds_ptr_next,
LSEDataType* const __restrict__ lse_lds_ptr_curr,
LSEDataType* const __restrict__ lse_lds_ptr_next,
DDataType* const __restrict__ d_lds_ptr_curr,
DDataType* const __restrict__ d_lds_ptr_next
) mutable {
constexpr bool is_prologue = is_prologue_.value;
constexpr bool is_epilogue = is_epilogue_.value;
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
constexpr bool is_main_body = is_prologue && is_epilogue;
if constexpr(is_prologue)
{
lse_block_tile = load_tile(lse_dram_window);
lse_lds_write_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_next);
async_load_tile(lse_lds_write_window, lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
d_block_tile = load_tile(d_dram_window);
d_lds_write_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_next);
async_load_tile(d_lds_write_window, d_dram_window);
move_tile_window(d_dram_window, {kM0});
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
@@ -510,6 +532,13 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
}
if constexpr(is_epilogue)
{
lse_lds_read_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_curr);
lse = load_tile(lse_lds_read_window);
d_lds_read_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_curr);
d = load_tile(d_lds_read_window);
}
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
__builtin_amdgcn_sched_barrier(0);
@@ -617,11 +646,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
__builtin_amdgcn_sched_barrier(0);
if constexpr(is_prologue)
{
store_tile(lse_lds_write_window, lse_block_tile);
store_tile(d_lds_write_window, d_block_tile);
}
if constexpr(is_epilogue)
{
// STAGE 5, P^T(PGrad^T - D)
@@ -676,13 +700,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
store_tile(ds_lds_window, ds_gemm);
}
__builtin_amdgcn_s_waitcnt(3952);
s_waitcnt</*vmcnt=*/0>();
block_sync_lds();
if constexpr(is_prologue)
{
q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
q_reg_tensor = load_tile(q_lds_read_window);
lse = load_tile(lse_lds_read_window);
}
if constexpr(is_epilogue)
{
@@ -720,7 +743,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
{
do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
do_reg_tensor = load_tile(do_lds_read_window);
d = load_tile(d_lds_read_window);
}
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
@@ -749,17 +771,25 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
};
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
const bool is_even = (i_total_bodys % 2 == 0);
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
const bool is_even = (i_total_bodys % 2 == 0);
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
const auto lse_lds_ptr_curr = is_even ? lse_lds_ptr1 : lse_lds_ptr0;
const auto lse_lds_ptr_next = is_even ? lse_lds_ptr0 : lse_lds_ptr1;
const auto d_lds_ptr_curr = is_even ? d_lds_ptr1 : d_lds_ptr0;
const auto d_lds_ptr_next = is_even ? d_lds_ptr0 : d_lds_ptr1;
main_body_impl(is_prologue_,
is_epilogue_,
q_lds_ptr_curr,
q_lds_ptr_next,
do_lds_ptr_curr,
do_lds_ptr_next);
do_lds_ptr_next,
lse_lds_ptr_curr,
lse_lds_ptr_next,
d_lds_ptr_curr,
d_lds_ptr_next);
i_total_bodys += 1;
};

View File

@@ -363,38 +363,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window = make_tile_window(
lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto lse_dram_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto lse_lds_read_window =
make_tile_window(lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
// D: HBM ->Reg
auto d_dram_window = make_tile_window(
d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto d_lds = make_tensor_view<address_space_enum::lds>(
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto d_lds_read_window =
make_tile_window(d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), true>(

View File

@@ -194,13 +194,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
return total_pixels / GetAlignmentOGrad<Problem>();
return GetTransposedAlignmentX<typename Problem::OGradDataType>();
}
template <typename Problem>
@@ -358,11 +352,30 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
Problem::BlockFmhaShape::kVHeaddim>();
}
template <typename Problem, typename BlockGemm>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution()
{
return BlockFmhaBwdPipelineDefaultPolicy::MakeLSEDDramTileDistribution<Problem,
BlockGemm>();
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t N0 = MWarp * NWarp;
constexpr index_t M1 = kMPerBlock;
constexpr index_t M0 = get_warp_size() / M1;
static_assert(M1 <= get_warp_size() && get_warp_size() % M1 == 0,
"M1 must be a factor of warp size");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N0, M0>,
tuple<sequence<M1, 1>>,
tuple<sequence<0>, sequence<0, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1>,
sequence<1>>{});
}
template <typename Problem>
@@ -793,9 +806,10 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
return lsed_lds_block_desc;
}
template <typename Problem, typename BlockGemm>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
@@ -984,15 +998,16 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE()
{
return sizeof(typename Problem::LSEDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
return static_cast<index_t>(max( //
sizeof(int) * get_warp_size(),
sizeof(typename Problem::LSEDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size()));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD()
{
return sizeof(typename Problem::DDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
return GetSmemSizeLSE<Problem>();
}
template <typename Problem>
@@ -1039,8 +1054,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
constexpr index_t smem_size_stage0 = smem_size_k + smem_size_v;
constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 + smem_size_lse +
smem_size_d + max(smem_size_bias, smem_size_ds);
constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 +
smem_size_lse * 2 + smem_size_d * 2 +
max(smem_size_bias, smem_size_ds);
return max(smem_size_stage0, smem_size_stage1);
}
@@ -1090,6 +1106,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
static constexpr index_t LSE_VMEM_READ = 1;
static constexpr index_t D_VMEM_READ = 1;
static constexpr index_t DQ_VMEM_WRITE = kM0 * kQKHeaddim / kBlockSize; // atomic add
// LDS Read
static constexpr index_t OGradT_LDS_READ =
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
@@ -1116,11 +1134,12 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t OGradT_LDS_WRITE =
kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
static constexpr index_t LSE_LDS_WRITE = 1;
static constexpr index_t D_LDS_WRITE = 1;
static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
public:
static constexpr index_t TOTAL_VMEM_READ =
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ + DQ_VMEM_WRITE;
CK_TILE_DEVICE static constexpr void SchedulerGemm0()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
@@ -1128,7 +1147,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
constexpr index_t VMEM_READ_INST =
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
constexpr index_t MFMA_INST = Gemm0MFMA;
constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
constexpr index_t LDS_READ_INST = OGradT_LDS_READ + LSE_LDS_READ + D_LDS_READ;
constexpr index_t lcm_inst = lcm(VMEM_READ_INST, MFMA_INST, LDS_READ_INST);
static_for<0, lcm_inst, 1>{}([&](auto i) {
@@ -1161,8 +1180,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
{
// Mem: LSE/D LDS store, SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
constexpr index_t LDS_WRITE_INST = LSE_LDS_WRITE + D_LDS_WRITE + SGradT_LDS_WRITE;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ;
constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ;
constexpr index_t MFMA_INST = Gemm3MFMA;
constexpr index_t lds_rw_inst = LDS_WRITE_INST + LDS_READ_INST;
@@ -1185,7 +1204,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ;
constexpr index_t MFMA_INST = Gemm4MFMA;
constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST);