[CK_TILE] FA bwd kernels optimization (#1397)

* tmp save

* fix batch deterministic bugs

* fix group deterministic bugs

* codegen update

* reorder files

* bias support

* hd256 bias support

* bwd smoke test update

* simplify convert dq

* fix hd256 dropout scratch

* do{}while() -> while(){}

* comments

* remove FmhaBwdTilePartitioner

* save clear_tile

* refactor dropout

* code cleanup

* code cleanup

* comments

* fix epilogue problem

* fix fwd dropout

* group convert_dq opt

* fix dq alignment

* Do not store storerandval in bwd for flash attention integration

* fix hd32 error and boost performance

* revert

* Remove duplicated WarpGemm definitions in the policy file

* dropout patch for mrepeat 16*16

* code sync up

* dq_acc stride

* dq_acc stride stuff

* codegen update

* fwd dropout revert

* fix hd128 scratches and boost performance

* receipt 3 for simplified smoke test

* more strides for fa integration

* fix hd64 scratches and boost performance

* non-iglp pipeline for headdim padding cases

* dpad same as dvpad for flash attention integration

* unpadded lse&d for group mode

* Support unpad layout for group lse

* Support unpad lse layout for splitkv

* Fix stride for splitkv kernel

* fix unpadded lse issue in fwd splitkv

* comment

* solve lds read&write conflicts

* rename

* bias rename

* tile index revert

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>

[ROCm/composable_kernel commit: 79a5d9c10c]
This commit is contained in:
Dan Yao
2024-08-17 04:40:10 +08:00
committed by GitHub
parent dffd5eacc0
commit 14402bb211
43 changed files with 5515 additions and 4222 deletions

View File

@@ -6,7 +6,7 @@ execute_process(
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3
)
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
@@ -23,7 +23,7 @@ add_custom_command(
add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3
)
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
@@ -55,11 +55,10 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
# ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
# Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)

View File

@@ -66,6 +66,22 @@ BIAS_CHECK_MAP = {
"alibi" : "bias_enum::alibi"
}
# Maps a dropout-variant key to the C++ type string that is pasted into the
# generated kernel source as the dropout policy.
# NOTE(review): judging only by the keys, the three template flags look like
# <dropout enabled, wg32 variant, store randval> -- confirm against the
# ck_tile::BlockDropoutBwd declaration.
DROPOUT_MAP = {
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
}
# Maps the same dropout-variant keys to the C++ boolean expression (over the
# runtime traits object `t`) used in the generated dispatch condition.
# The wg32 and wg16 entries are textually identical here, so these expressions
# alone cannot distinguish the two variants at dispatch time.
DROPOUT_CHECK_MAP = {
"no" : "t.has_dropout == false",
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
}
MODE_MAP = {
"batch" : "false",
"group" : "true"

View File

@@ -14,15 +14,13 @@ from codegen.cpp_symbol_map import *
BWD_DQDKDV_PIPELINE_MAP = {
"ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR",
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
"ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR",
"kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP",
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR",
}
BWD_DQDKDV_PIPELINE_ENUM_MAP = {
"ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR",
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS",
"ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR",
"kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP",
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR",
}
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
@@ -34,39 +32,42 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
FMHA_BWD_DQ_DK_DV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_tile_{F_idx} = ck_tile::
sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps2_{F_idx},
fmha_warp_tile_{F_idx}>;
fmha_block_warps0_{F_idx},
fmha_warp_tile0_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile1_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile0_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile1_{F_idx},
fmha_block_warps2_{F_idx},
fmha_warp_tile0_{F_idx}>;
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_bias},
{F_dbias},
false,
{F_dropout},
false,
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
{F_skpad},
{F_dpad},
{F_dvpad},
{F_bias},
{F_dbias},
false,
false,
false,
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_dropout_{F_idx} = {F_dropout};
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
@@ -86,55 +87,72 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
fmha_bwd_shape_{F_idx},
{F_mode},
{F_deterministic},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
fmha_bwd_trait_{F_idx}>;
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<
fmha_bwd_pipeline_problem_{F_idx}>;
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<fmha_bwd_pipeline_problem_{F_idx}>;
using fmha_bwd_dk_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
false, false>>;
using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
{F_skpad},
{F_dpad}>>;
using fmha_bwd_dv_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
false, false>>;
using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
{F_skpad},
{F_dvpad}>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>,
fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>;
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dtype},
{F_mode},
{F_pipeline_enum},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
{F_bias},
{F_dbias},
{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_deterministic}>;
#include <iostream>
template<>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template<>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template<>
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
@@ -146,14 +164,15 @@ FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_API="""
#include <iostream>
template<typename dot_do_o_trait_, typename dq_dk_dv_trait_>
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
}}
@@ -173,38 +192,36 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}}
"""
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_>(s, a);
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r;
}}
"""
@dataclass
class FmhaBwdDQDKDVApiTrait:
pipeline : str
pipeline : str
# sync with fmha_bwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along k seqlen
bhdq : int # q head_dim
bhdv : int # v head_dim
mask : str
bias : str
dbias : str
dropout : str
spad : str
skpad : str
dpad : str
dvpad : str
@property
def name(self) -> str:
return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along k seqlen
bhdq : int # q head_dim
bhdv : int # v head_dim
mask : str
bias : str
dbias : str
dropout : str
spad : str
skpad : str
dpad : str
dvpad : str
deterministic : str
def scheck(self, spad1 : str) -> str:
if self.mode == 'group':
@@ -212,9 +229,9 @@ class FmhaBwdDQDKDVApiTrait:
elif self.spad == 't' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} != 0'
elif self.spad == 'f' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
else: # self.spad == 'f' and spad1 == 'f'
return f'a.seqlen_q % 256 == 0' # BlockSize
return f'a.seqlen_q % 64 == 0'
@property
def skcheck(self) -> str:
@@ -256,16 +273,19 @@ class FmhaBwdApiPool:
per_hdim_case=str()
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
traits=self.dq_dk_dv_pool[dtype][hdim]
hdim_int = int(hdim)
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
for spad1 in ["t", "f"]:
if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")):
if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
continue
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad])
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
@@ -295,81 +315,89 @@ class FmhaBwdDQDKDVTileSize:
F_bhdv : int # v head_dim
F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2
F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2
F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2
F_rk0 : int # number of warps along headdim_qk/v (not used) in gemm0/gemm2
F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3
F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3
F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3
F_rm2 : int # number of warps along k seqlen (block warps) in gemm4
F_rn2 : int # number of warps along q seqlen (block warps) in gemm4
F_rk2 : int # number of warps along gemm-k (not used) in gemm4
F_wm : int # warp size along m (warp size)
F_wn : int # warp size along n
F_wk : int # warp size along k
F_rn1 : int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3
F_rk1 : int # number of warps along q seqlen (not used) in gemm1/gemm3
F_rm2 : int # number of warps along q seqlen (block warps) in gemm4
F_rn2 : int # number of warps along headdim_qk (block warps) in gemm4
F_rk2 : int # number of warps along k seqlen (not used) in gemm4
F_wm0 : int # warp size along m in gemm0/gemm2/gemm4
F_wn0 : int # warp size along n in gemm0/gemm2/gemm4
F_wk0 : int # warp size along k in gemm0/gemm2/gemm4
F_wm1 : int # warp size along m in gemm1/gemm3
F_wn1 : int # warp size along n in gemm1/gemm3
F_wk1 : int # warp size along k in gemm1/gemm3
F_occupancy : int # occupancy
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}"
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
@dataclass
class FmhaBwdDQDKDVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaBwdDQDKDVTileSize
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_bias : str #
F_dbias : str #
F_dropout : str #
F_mask : str # value from MASK_MAP
F_mode : str # value from MODE_MAP
F_pipeline : str
mask_impl : str
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaBwdDQDKDVTileSize
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_bias : str #
F_dbias : str #
F_dropout : str #
F_mask : str # value from MASK_MAP
F_mode : str # value from MODE_MAP
F_deterministic : str #
F_pipeline : str #
mask_impl : str #
@property
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + \
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bk1 = self.F_tile.F_bk1,
F_bk2 = self.F_tile.F_bk2,
F_bk3 = self.F_tile.F_bk3,
F_bk4 = self.F_tile.F_bk4,
F_bhdq = self.F_tile.F_bhdq,
F_bhdv = self.F_tile.F_bhdv,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_rm2 = self.F_tile.F_rm2,
F_rn2 = self.F_tile.F_rn2,
F_rk2 = self.F_tile.F_rk2,
F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk,
F_spad = BOOL_MAP[self.F_spad],
F_skpad = BOOL_MAP[self.F_skpad],
F_dpad = BOOL_MAP[self.F_dpad],
F_dvpad = BOOL_MAP[self.F_dvpad],
F_bias = BIAS_MAP[self.F_bias],
F_dbias = BOOL_MAP[self.F_dbias],
F_dropout = BOOL_MAP[self.F_dropout],
F_occupancy = self.F_tile.F_occupancy,
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bk1 = self.F_tile.F_bk1,
F_bk2 = self.F_tile.F_bk2,
F_bk3 = self.F_tile.F_bk3,
F_bk4 = self.F_tile.F_bk4,
F_bhdq = self.F_tile.F_bhdq,
F_bhdv = self.F_tile.F_bhdv,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_rm2 = self.F_tile.F_rm2,
F_rn2 = self.F_tile.F_rn2,
F_rk2 = self.F_tile.F_rk2,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_spad = BOOL_MAP[self.F_spad],
F_skpad = BOOL_MAP[self.F_skpad],
F_dpad = BOOL_MAP[self.F_dpad],
F_dvpad = BOOL_MAP[self.F_dvpad],
F_bias = BIAS_MAP[self.F_bias],
F_dbias = BOOL_MAP[self.F_dbias],
F_dropout = DROPOUT_MAP[self.F_dropout],
F_occupancy = self.F_tile.F_occupancy,
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_deterministic = BOOL_MAP[self.F_deterministic],
F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
@property
def name(self) -> str:
@@ -382,7 +410,7 @@ class FmhaBwdDQDKDVKernel:
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}'
if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_dbias == 't' : n += '_dbias'
@@ -390,7 +418,8 @@ class FmhaBwdDQDKDVKernel:
if self.F_mask == 's_mask': n += f'_mask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_dropout == 't' : n += '_dropout'
if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
if self.F_deterministic == 't' : n += '_deterministic'
return n
@property
@@ -413,19 +442,23 @@ class FmhaBwdDQDKDVKernel:
spad=self.F_spad,
skpad=self.F_skpad,
dpad=self.F_dpad,
dvpad=self.F_dvpad)
dvpad=self.F_dvpad,
deterministic=self.F_deterministic
)
# TODO: design a more practical way to do it
# These are the currently supported tile sizes & pipelines.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1),
"qs_ks_vr_dos"],
'64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
"qs_ks_vr_dos"],
'128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
"ks_vr"]
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"]
}
else:
return None
@@ -440,7 +473,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
tile = d[hdim_str][0]
ppl = d[hdim_str][1]
hdim = int(hdim_str)
@@ -448,16 +481,29 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue
if ("wg32" in dropout):
continue
if (dpad == "t" or dvpad == "t"):
ppl = d[hdim_str][2]
k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
F_pipeline=ppl, mask_impl=mask_impl)
F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
if not cond:
continue
if receipt == 3:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond:
continue
api_pool.register_dq_dk_dv_traits(k.api_trait())
@@ -468,53 +514,54 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad},
{F_dvpad},
{F_occupancy}>;
using fmha_bwd_dot_do_o_trait_{F_idx} =
ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>;
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
/* BlockSize = */ 256,
/* BlockSize = */ 64,
{F_hdim},
{F_mode},
fmha_bwd_dot_do_o_trait_{F_idx}>;
using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO<
fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_{F_idx} =
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_kernel_{F_idx} =
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdOGradDotOTilePartitioner</* BlockSize = */ 256>,
fmha_bwd_dot_do_o_{F_idx}>;
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;
using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
using dot_do_o_trait_{F_idx} =
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
#include <iostream>
template<>
template <>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template<>
template <>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template<>
template <>
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
@@ -584,12 +631,150 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
return gen
# C++ source template for one generated convert-dq kernel instance.
# It is rendered with str.format() by FmhaBwdConvertQGradKernel.template, so:
#   * every {F_*} token is a placeholder (F_idx, F_dtype, F_spad, F_dpad,
#     F_occupancy, F_bm0, F_bn0, F_hdim, F_mode, F_deterministic);
#   * literal C++ braces are written doubled ({{ and }}).
# The rendered code declares the kernel type aliases and three explicit
# specializations keyed on convert_dq_trait_{F_idx}:
#   * fmha_bwd_convert_dq_          - launches via ck_tile::launch_kernel and
#                                     returns its float result;
#   * fmha_bwd_convert_dq_oneshot_  - invokes the kernel once on s.stream_id_;
#   * fmha_bwd_convert_dq_get_name_ - returns k_::GetName().
# The text below is emitted verbatim into generated files; do not reformat it.
FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_convert_dq_trait_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;
using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
/* BlockSize = */ 256,
{F_bm0},
{F_bn0},
{F_hdim},
{F_mode},
{F_deterministic},
fmha_bwd_convert_dq_trait_{F_idx}>;
using fmha_bwd_convert_dq_{F_idx} =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
using fmha_bwd_convert_dq_kernel_{F_idx} =
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_dtype},
{F_mode},
{F_spad},
{F_dpad},
{F_deterministic}>;
#include <iostream>
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
return k_::GetName();
}}
"""
@dataclass
class FmhaBwdConvertQGradKernel:
    """Parameters of one generated convert-dq kernel instance.

    The boolean-like fields (F_spad, F_dpad, F_deterministic) hold the
    single-character flags 't' / 'f'; they are translated through BOOL_MAP
    when the C++ source is rendered.
    """
    F_idx           : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim          : int  # hdim
    F_dtype         : str  # data type (key of DTYPE_MAP)
    F_bm0           : int  # tile size along q seqlen (block size)
    F_bn0           : int  # tile size along k seqlen
    F_spad          : str  # 't'/'f'
    F_dpad          : str  # 't'/'f'
    F_mode          : str  # value from MODE_MAP
    F_occupancy     : int  # forwarded to the traits template as {F_occupancy}
    F_deterministic : str  # 't'/'f'

    @property
    def template(self) -> str:
        """Return the full C++ source text for this kernel instance.

        The shared file header is followed by the convert-dq body template
        with all of its {F_*} placeholders filled in.
        """
        return FMHA_BWD_KERNEL_HEADER + \
            FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
                F_idx           = self.F_idx,
                F_hdim          = self.F_hdim,
                F_dtype         = DTYPE_MAP[self.F_dtype],
                F_bm0           = self.F_bm0,
                F_bn0           = self.F_bn0,
                F_spad          = BOOL_MAP[self.F_spad],
                F_dpad          = BOOL_MAP[self.F_dpad],
                F_mode          = MODE_MAP[self.F_mode],
                F_occupancy     = self.F_occupancy,
                F_deterministic = BOOL_MAP[self.F_deterministic])

    @property
    def name(self) -> str:
        """Return the unique kernel name, also used as the file stem.

        Layout: fmha_bwd_convert_dq_d<hdim>_<dtype>_b<bm0>x<bn0>_<mode>_o<occupancy>
        followed by an optional '_p[s][d]' padding tag and an optional
        '_deterministic' suffix.
        """
        def pad_name() -> str:
            # 's' / 'd' mark seqlen / head-dim padding; a leading 'p' is added
            # only when at least one of them is enabled.
            n = ''
            if self.F_spad == 't':
                n += 's'
            if self.F_dpad == 't':
                n += 'd'
            if n != '':
                n = 'p' + n
            return n

        pn = pad_name()
        n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
        if pn != '':
            n += f'_{pn}'
        if self.F_deterministic == 't':
            n += '_deterministic'
        return n

    @property
    def filename(self) -> str:
        """Return the name of the generated source file for this kernel."""
        return self.name + ".cpp"
def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
    """Enumerate every convert-dq kernel instance to generate.

    For each dtype that has a dq/dk/dv tile dictionary, one kernel is produced per
    (hdim, mode, spad, dpad, deterministic) combination. Group mode with seqlen
    padding disabled is skipped. ``F_bn0`` is taken from the first tile entry
    registered for the head dimension; ``F_bm0`` is fixed to 64.
    """
    # TODO: we don't support tuning yet, so pick up one value for pad/occupancy
    #       support this in future
    def get_occupancy(dtype, hdim):
        return 2

    gen = []

    for dtype in DTYPE_MAP.keys():
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
        if d is None:
            continue
        for hdim_str, mode, spad, dpad, deterministic in itertools.product(
                d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
            # group mode is only generated with seqlen padding enabled
            if mode == "group" and spad == "f":
                continue
            hdim = int(hdim_str)
            tile = d[hdim_str][0]
            k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
                                          F_spad=spad, F_dpad=dpad, F_mode=mode,
                                          F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
            gen.append(k)

    return gen
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
    """Write the rendered source of one dq/dk/dv kernel to ``autogen_dir / kernel.filename``."""
    destination = autogen_dir / kernel.filename
    destination.write_text(kernel.template)
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
    """Write the rendered source of one dot(dO, O) kernel to ``autogen_dir / kernel.filename``."""
    destination = autogen_dir / kernel.filename
    destination.write_text(kernel.template)
def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None:
    """Write the rendered source of one convert-dq kernel to ``autogen_dir / kernel.filename``."""
    destination = autogen_dir / kernel.filename
    destination.write_text(kernel.template)
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
    """Write the generated backward API source to ``autogen_dir / FMHA_BWD_API_FILENAME``."""
    api_file = autogen_dir / FMHA_BWD_API_FILENAME
    api_file.write_text(api_pool.api)
@@ -597,6 +782,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels:
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
kernels = get_bwd_convert_dq_blobs()
for kernel in kernels:
write_single_bwd_convert_dq_kernel(kernel, output_dir)
api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
@@ -605,6 +793,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
with file_path.open('a') as f:
kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
kernels = get_bwd_convert_dq_blobs()
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)

View File

@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[])
.insert("drop_offset", "0", "offset for random number generator")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("deterministic",
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
"will not be used");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -128,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0)
hdim_v = hdim_q;
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
{
std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
return false;
}
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
@@ -177,9 +176,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
seed.reset();
}
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
bool deterministic = arg_parser.get_bool("deterministic");
ck_tile::stream_config stream_config{nullptr,
true,
@@ -265,6 +265,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
const ck_tile::index_t nsplits =
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
@@ -284,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<LSEDataType> lse_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<DDataType> d_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<RandValOutputDataType> randval_host(
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
@@ -302,6 +305,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
if(init_method == 0)
{
@@ -362,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
@@ -387,8 +395,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask
<< std::flush;
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
<< ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
// get_element_space_size_in_bytes() already reports bytes (the same value is used to size
// dq_acc_buf), so it must not be scaled by sizeof(AccDataType) again before the MByte
// conversion. Integer division: workspaces smaller than 1 MByte are reported as 0.
const std::size_t workspace_size =
    dq_acc_host.get_element_space_size_in_bytes() / (1024 * 1024);
if(deterministic)
{
    std::cout << "\nDeterministic mode ON: " << workspace_size
              << " MByte memory workspace allocated" << std::endl;
}
auto fmha_traits = fmha_bwd_traits{hdim_q,
hdim_v,
@@ -397,7 +414,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.type,
bias.type,
use_dbias,
p_drop > 0.0f};
p_drop > 0.0f,
s_randval,
deterministic};
auto fmha_args = [&]() {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
@@ -422,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_lsed = max_seqlen_q;
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
// setup batch_stride_* arguments
@@ -433,10 +452,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
@@ -452,6 +473,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
dk_buf.GetDeviceBuffer(),
dv_buf.GetDeviceBuffer(),
dbias_buf.GetDeviceBuffer(),
dq_acc_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
@@ -473,6 +495,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_o,
stride_randval,
stride_do,
stride_q, // stride_dq_acc
stride_q, // stride_dq
stride_dk,
stride_dv,
stride_dbias,
@@ -484,6 +508,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_q, // nhead_stride_dq_acc
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
@@ -493,15 +521,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_q, // batch_stride_dq_acc
batch_stride_q, // batch_stride_dq
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_drop,
p_undrop,
s_randval,
{drop_seed, drop_offset}};
}();
@@ -719,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); });
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
// clang-format on
q_host_refs.push_back(q_host_ref);
@@ -738,6 +768,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
lse_buf.ToDevice(lse_host.data());
dq_buf.SetZero();
dbias_buf.SetZero();
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};

View File

@@ -77,6 +77,7 @@ struct fmha_bwd_args
void* dk_ptr;
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
@@ -97,6 +98,8 @@ struct fmha_bwd_args
ck_tile::index_t stride_o;
ck_tile::index_t stride_randval;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias;
@@ -108,6 +111,10 @@ struct fmha_bwd_args
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
ck_tile::index_t nhead_stride_dbias;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -117,15 +124,17 @@ struct fmha_bwd_args
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias;
ck_tile::index_t split_stride_dq_acc;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
float p_undrop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
@@ -145,10 +154,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
@@ -163,6 +172,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
@@ -173,13 +183,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.batch_stride_lsed,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
@@ -192,10 +204,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
@@ -209,6 +221,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
@@ -219,6 +232,9 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.batch_stride_q,
args.batch_stride_k,
@@ -227,14 +243,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
args.batch_stride_dq_acc,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
@@ -260,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_lsed);
args.nhead_stride_lsed);
}
else
{ // create batch mode kernel arguments
@@ -286,19 +302,59 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
return ck_tile::make_tuple(kargs, grids);
}
// Builds the kernel-argument pack and the launch grid for one convert-dq kernel type.
// Group-mode kernels receive the seqstart pointers; batch-mode kernels receive explicit
// sequence lengths plus the batch strides. The branch is selected at compile time from
// FmhaBwdConvertQGradKernel::kIsGroupMode.
// Returns ck_tile::make_tuple(kargs, grids).
template <typename FmhaBwdConvertQGradKernel>
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
{
    using Kernel = FmhaBwdConvertQGradKernel;

    const auto build_kargs = [&args] {
        if constexpr(!Kernel::kIsGroupMode)
        {
            // batch mode kernel arguments
            return Kernel::MakeKargs(args.dq_acc_ptr,
                                     args.dq_ptr,
                                     args.seqlen_q,
                                     args.seqlen_k,
                                     args.hdim_q,
                                     args.stride_dq,
                                     args.stride_dq_acc,
                                     args.nhead_stride_dq,
                                     args.nhead_stride_dq_acc,
                                     args.batch_stride_dq,
                                     args.batch_stride_dq_acc,
                                     args.split_stride_dq_acc);
        }
        else
        {
            // group mode kernel arguments
            return Kernel::MakeKargs(args.dq_acc_ptr,
                                     args.dq_ptr,
                                     args.seqstart_q_ptr,
                                     args.seqstart_k_ptr,
                                     args.hdim_q,
                                     args.stride_dq,
                                     args.stride_dq_acc,
                                     args.nhead_stride_dq,
                                     args.nhead_stride_dq_acc,
                                     args.split_stride_dq_acc);
        }
    };

    auto kargs = build_kargs();
    dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
    return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
typename FmhaMask_,
typename FmhaDropout_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kHasDropout_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
bool kPadDv_,
bool kIsDeterministic_>
struct fmha_bwd_dq_dk_dv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -306,13 +362,14 @@ struct fmha_bwd_dq_dk_dv_traits_
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
@@ -343,6 +400,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dot_do_o_get_name_();
// Compile-time tag identifying one convert-dq kernel configuration. It only re-exports
// its template arguments as members; DataType has cv/ref qualifiers removed.
template <ck_tile::index_t HDim_,
          typename DataType_,
          bool kIsGroupMode_,
          bool kPadS_,
          bool kPadD_,
          bool kIsDeterministic_>
struct fmha_bwd_convert_dq_traits_
{
    using DataType = ck_tile::remove_cvref_t<DataType_>;

    static constexpr ck_tile::index_t HDim = HDim_;

    static constexpr bool kIsGroupMode     = kIsGroupMode_;
    static constexpr bool kIsDeterministic = kIsDeterministic_;

    // padding flags
    static constexpr bool kPadS = kPadS_;
    static constexpr bool kPadD = kPadD_;
};
// Entry points of the convert-dq kernel, keyed by a traits type. Only declared here;
// explicit specializations are emitted by the code generator, one per kernel instance.

// Launches the convert-dq kernel for Traits_ and returns the float result of the launch.
template <typename Traits_>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
// Invokes the convert-dq kernel for Traits_ once; returns nothing.
template <typename Traits_>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
// Returns the name of the convert-dq kernel selected by Traits_.
template <typename Traits_>
std::string fmha_bwd_convert_dq_get_name_();
// This is the public API, will be generated by script
struct fmha_bwd_traits
{
@@ -354,6 +436,8 @@ struct fmha_bwd_traits
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_dbias;
bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// TODO: padding check is inside this api
};
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);

View File

@@ -479,16 +479,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 2>{1, 1});
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
1 < num_splits
? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// self define lse data layout as [batch, nhead, max_seqlen_q]
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<ODataType> o_host(
@@ -669,8 +671,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
@@ -679,12 +681,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
@@ -996,8 +998,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b, idx[0], idx[1] + query_offset);
});
cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,

View File

@@ -185,7 +185,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse,
args.window_size_left,
args.window_size_right,
args.mask_type,
@@ -284,7 +283,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
@@ -376,9 +374,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}

View File

@@ -11,18 +11,19 @@ COMMON_ARGS='-v=1'
set -x
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 32 64 128 ; do
for hdim in 32 64 128 256 ; do
for mode in 0 1 ; do
for bias in "n" "e" "a"; do
for dbias in 0 1 ; do
for p_drop in 0.0 0.2; do
for bias in "n" "a" ; do
for dbias in 0 ; do
for p_drop in 0.0 0.2 ; do
for deterministic in 0 ; do
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done
done
@@ -31,4 +32,5 @@ done
done
done
done
done
set +x

View File

@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
};
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths, typename RightShift>
template <typename LowLengths>
struct xor_t : public base_transform<2, 2>
{
static constexpr auto type_enum = coord_transform_enum::xor_t;
@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
using UpLengths = LowLengths;
UpLengths up_lengths_;
RightShift right_shift_;
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths,
const RightShift& right_shift)
: up_lengths_{low_lengths}, right_shift_{right_shift}
{
}
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
idx_low(number<0>{}) = idx_up[number<0>{}];
const auto idx_low_1_tmp =
(idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
const auto idx_low_1 =
(idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
idx_low(number<1>{}) = idx_low_1;
idx_low(number<1>{}) =
idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<RightShift>::value;
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
// MUST be static function
@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
array<index_t, 2> up_vector_lengths = low_vector_lengths;
array<index_t, 2> up_vector_strides = low_vector_strides;
if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
{
if(low_vector_lengths[1] != -1)
{
up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
}
}
return make_tuple(up_vector_lengths, up_vector_strides);
}
@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
print(up_lengths_);
printf(", ");
//
printf("right_shift_: ");
print(right_shift_);
printf("}");
}
};
@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
return modulo<Modulus, UpLength>{modulus, up_length};
}
template <typename LowLengths, typename RightShift>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
const RightShift& right_shift)
template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
{
return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
return xor_t<LowLengths>{low_lengths};
}
template <typename LowLength, typename OffsetLength>

View File

@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// u32
// using uint32_t = ...
// vector aliases with 2, 4, 8, 16, 32 and 64 uint32_t lanes, declared through the
// compiler's ext_vector_type attribute
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
// i16
// using int16_t = ...
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));

View File

@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
return make_tuple(
make_static_tile_distribution(
tile_distribution_encoding<typename Encoding::RsLengths,
decltype(sliced_h_lengths), // only need to change the
// h_lengths type
remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
// change the
// h_lengths type
typename Encoding::Ps2RHssMajor,
typename Encoding::Ps2RHssMinor,
typename Encoding::Ys2RHsMajor,

View File

@@ -53,6 +53,39 @@ class philox
out_tmp[3] = tmp_ph.w;
}
// Writes 8 bytes to `out`: the 32-bit words `start_idx` and `start_idx + 2` of the four
// words returned by get_philox_4x32(subsequence).
// start_idx + 2 indexes a 4-lane vector, so start_idx is expected to be 0 or 1.
// NOTE(review): `out` is written through a uint32_t*, so it presumably has to be suitably
// aligned for 32-bit stores -- confirm at the call sites.
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
                                        const unsigned long long subsequence,
                                        const index_t start_idx) const
{
    uint4 philox_out = get_philox_4x32(subsequence);

    // copy into a vector type so the words can be selected with a runtime index
    uint32x4_t words;
    words[0] = philox_out.x;
    words[1] = philox_out.y;
    words[2] = philox_out.z;
    words[3] = philox_out.w;

    uint32_t* dst = reinterpret_cast<uint32_t*>(out);
    dst[0]        = words[start_idx];
    dst[1]        = words[start_idx + 2];
}
// Writes 4 bytes to `out`: the 32-bit word `start_idx` of the four words returned by
// get_philox_4x32(subsequence). start_idx indexes a 4-lane vector (valid range 0..3).
// NOTE(review): `out` is written through a uint32_t*, so it presumably has to be suitably
// aligned for a 32-bit store -- confirm at the call sites.
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
                                        const unsigned long long subsequence,
                                        const index_t start_idx) const
{
    uint4 philox_out = get_philox_4x32(subsequence);

    // copy into a vector type so the word can be selected with a runtime index
    uint32x4_t words;
    words[0] = philox_out.x;
    words[1] = philox_out.y;
    words[2] = philox_out.z;
    words[3] = philox_out.w;

    uint32_t* dst = reinterpret_cast<uint32_t*>(out);
    dst[0]        = words[start_idx];
}
private:
struct ull2
{

View File

@@ -8,21 +8,16 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"

View File

@@ -286,11 +286,226 @@ struct BlockDropout
});
}
ck_tile::philox ph;
const float rp_undrop;
const uint8_t p_undrop_in_uint8_t;
const bool is_store_randval;
};
// Dropout helper configured entirely by compile-time flags.
//   IsDropout_      : selects between the no-op (false) and the active (true)
//                     partial specializations that follow.
//   IsWG32_         : true = 32*32 warp gemm, false = 16*16 warp gemm.
//   IsStoreRandval_ : exposed by the specializations as IsStoreRandval.
// The primary template is only declared; every use goes through one of the
// partial specializations on IsDropout_.
template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd;
template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
{
static constexpr bool IsDropout = false;
static constexpr bool IsStoreRandval = IsStoreRandval_;
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
__host__ __device__ static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
(void)randval_dram_block_window_tmp;
(void)seqlen_qk_start;
return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
}
};
template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
{
static constexpr bool IsDropout = true;
// true: 32*32 warp gemm
// false: 16*16 warp gemm
static constexpr bool IsWG32 = IsWG32_;
static constexpr bool IsStoreRandval = IsStoreRandval_;
// Constructs the per-thread dropout state.
//
// `ph` is built from `seed` and a second argument computed as
//   offset + (i_batch * nheads + i_head) * get_warp_size() + <lane term>
// where the lane term is
//   IsWG32 == true : get_lane_id()
//   IsWG32 == false: get_lane_id() with bit 4 cleared (47 == 0b101111), plus the
//                    lowest bit of get_warp_id() placed into bit 4.
// NOTE(review): the 16*16 lane remapping presumably pairs two warps so that their
// random values line up with the 32*32 layout - confirm against the Run() callers.
//
// rp_undrop_ and p_undrop_in_uint8_t_ are stored unchanged in rp_undrop and
// p_undrop_in_uint8_t.
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch,
index_t i_head,
index_t nheads,
unsigned long long seed,
unsigned long long offset,
float rp_undrop_,
uint8_t p_undrop_in_uint8_t_)
: ph(seed,
offset + (i_batch * nheads + i_head) * get_warp_size() +
(IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
rp_undrop(rp_undrop_),
p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
{
}
// Creates the DRAM tile window through which random values are written, sized
// to one (M step) x (N step) tile of the given block gemm.
//
// The window shares the bottom tensor view of `randval_dram_block_window_tmp`.
//   IsFwd == true : origin = (row origin of the incoming window, seqlen_qk_start)
//   IsFwd == false: origin = (seqlen_qk_start, column origin of the incoming window)
// In the backward 16*16 warp-gemm case with more than 16 rows per block, one step
// spans twice the usual number of rows.
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                      index_t seqlen_qk_start)
{
    using GemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;

    constexpr auto warp_cfg =
        BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
    using WarpGemm = remove_cvref_t<decltype(warp_cfg.template at<0>())>;

    constexpr index_t kNumMWarps = warp_cfg.template at<1>();
    constexpr index_t kNumNWarps = warp_cfg.template at<2>();

    constexpr bool kDoubleMStep = (!IsFwd) && (!IsWG32) && (GemmShape::kM > 16);

    constexpr index_t kStepM = kNumMWarps * WarpGemm::kM * (kDoubleMStep ? 2 : 1);
    constexpr index_t kStepN = kNumNWarps * WarpGemm::kN;

    const auto incoming_origin = randval_dram_block_window_tmp.get_window_origin();

    if constexpr(IsFwd)
    {
        return make_tile_window(randval_dram_block_window_tmp.get_bottom_tensor_view(),
                                ck_tile::make_tuple(number<kStepM>{}, number<kStepN>{}),
                                {incoming_origin.at(number<0>{}), seqlen_qk_start}); // M/N
    }
    else
    {
        return make_tile_window(randval_dram_block_window_tmp.get_bottom_tensor_view(),
                                ck_tile::make_tuple(number<kStepM>{}, number<kStepN>{}),
                                {seqlen_qk_start, incoming_origin.at(number<1>{})}); // M/N
    }
}
// Describes the LDS buffer that holds one step of random values as a 2-D
// (rows, cols) tensor of extent (MWarp * WG::kM, WG::kN).
//
// Storage is laid out as (kColGroups, kRows, kColsPerGroup) with kColsPerGroup == 8.
// The stride of the outermost dimension is (kRows + 1) * kColsPerGroup, i.e. each
// column group is followed by one extra row's worth of elements.
// NOTE(review): that padding presumably exists to avoid LDS bank conflicts - confirm.
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
{
    constexpr auto warp_cfg =
        BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
    using WarpGemm = remove_cvref_t<decltype(warp_cfg.template at<0>())>;

    constexpr index_t kNumMWarps    = warp_cfg.template at<1>();
    constexpr index_t kRows         = kNumMWarps * WarpGemm::kM;
    constexpr index_t kCols         = WarpGemm::kN;
    constexpr index_t kColsPerGroup = 8;
    constexpr index_t kColGroups    = kCols / kColsPerGroup;

    constexpr auto padded_3d_desc = make_naive_tensor_descriptor(
        ck_tile::make_tuple(number<kColGroups>{}, number<kRows>{}, number<kColsPerGroup>{}),
        ck_tile::make_tuple(
            number<(kRows + 1) * kColsPerGroup>{}, number<kColsPerGroup>{}, number<1>{}),
        number<kColsPerGroup>{},
        number<1>{});

    // Fold (group, col-in-group) back into a single column dimension.
    return transform_tensor_descriptor(
        padded_3d_desc,
        ck_tile::make_tuple(make_pass_through_transform(number<kRows>{}),
                            make_merge_transform(ck_tile::make_tuple(
                                number<kColGroups>{}, number<kColsPerGroup>{}))),
        ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
        ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));
}
// Builds the static tile distribution used for the generated random-value tile.
//
// Outer part: (M iterations, MWarp) x (1, NWarp). The M iteration count is 2 in
// the backward 16*16 warp-gemm case with more than 16 rows per block, else 1.
// Inner part: the C-warp distribution encoding of a fixed warp gemm chosen from
// the data types (half_t/half_t/float vs. everything else) and IsWG32.
template <typename BlockGemm, bool IsFwd = true>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
{
    using GemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;

    constexpr auto warp_cfg =
        BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
    constexpr index_t kNumMWarps = warp_cfg.template at<1>();
    constexpr index_t kNumNWarps = warp_cfg.template at<2>();

    constexpr bool kTwoMIters = (!IsFwd) && (!IsWG32) && (GemmShape::kM > 16);

    constexpr index_t kMIters = kTwoMIters ? 2 : 1;
    constexpr index_t kNIters = 1;

    constexpr auto outer_encoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<kMIters, kNumMWarps>, sequence<kNIters, kNumNWarps>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 1>>,
        sequence<1, 2>,
        sequence<0, 0>>{};

    // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
    // except headdim256.
    constexpr auto inner_encoding = []() {
        constexpr bool kIsHalfGemm = std::is_same_v<typename BlockGemm::ADataType, half_t> &&
                                     std::is_same_v<typename BlockGemm::BDataType, half_t> &&
                                     std::is_same_v<typename BlockGemm::CDataType, float>;

        if constexpr(kIsHalfGemm && IsWG32)
        {
            return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
        }
        else if constexpr(kIsHalfGemm)
        {
            return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{};
        }
        else if constexpr(IsWG32)
        {
            return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
        }
        else
        {
            return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{};
        }
    }();

    return make_static_tile_distribution(
        detail::make_embed_tile_distribution_encoding(outer_encoding, inner_encoding));
}
// Builds the static tile distribution used when reading random values back
// from LDS.
//
// Outer part: (1, MWarp) x (1, NWarp). Inner part: the C-warp distribution
// encoding of the block gemm's own warp gemm.
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
{
    constexpr auto warp_cfg =
        BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
    using WarpGemm = remove_cvref_t<decltype(warp_cfg.template at<0>())>;

    constexpr index_t kNumMWarps = warp_cfg.template at<1>();
    constexpr index_t kNumNWarps = warp_cfg.template at<2>();

    constexpr index_t kMIters = 1;
    constexpr index_t kNIters = 1;

    constexpr auto outer_encoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<kMIters, kNumMWarps>, sequence<kNIters, kNumNWarps>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 1>>,
        sequence<1, 2>,
        sequence<0, 0>>{};

    return make_static_tile_distribution(detail::make_embed_tile_distribution_encoding(
        outer_encoding, typename WarpGemm::CWarpDstrEncoding{}));
}
template <typename BlockGemm,
typename PComputeDataType,
typename RandValOutputDataType,
typename PComputeWindow,
typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
const index_t start_m0_idx,
const index_t start_n0_idx,
PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const
{
@@ -305,30 +520,177 @@ struct BlockDropout
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
// register distribute
auto randval =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
static_assert(randval.kThreadElementSpaceSize == 16);
// randval tile in LDS
auto randval_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
int block_row_start = (start_m0_idx / WG::kM) + i_m0;
int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
auto randval_lds_window = make_tile_window(
randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
// register distribute
auto randval_dist_generated =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
auto randval_lds_read_window =
make_tile_window(randval_lds_window.get_bottom_tensor_view(),
randval_lds_window.get_window_lengths(),
randval_lds_window.get_window_origin(),
MakeRandValLdsShuffleTileDistribution<BlockGemm>());
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
int block_col_start = (start_n0_idx / WG::kN) + i_n0;
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
// save to LDS
store_tile(randval_lds_window, randval_dist_generated);
block_sync_lds();
// read from LDS to register
auto randval = load_tile(randval_lds_read_window);
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
constexpr auto p_idx1 =
tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
? p_compute[p_idx] * rp_undrop
: PComputeDataType(0);
});
});
// save to Global
if constexpr(IsStoreRandval)
{
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {0, kNPerStep});
}
});
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
}
});
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
}
}
template <typename BlockGemm,
typename RandValOutputDataType,
typename PComputeWindow,
typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
const index_t start_n0_idx,
PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16);
constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
constexpr index_t kMPerStep = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return MWarp * WG::kM * 2;
}
else
{
return MWarp * WG::kM;
}
}();
constexpr index_t kNPerStep = NWarp * WG::kN;
// register distribute
auto randval = make_static_distributed_tensor<uint8_t>(
MakeRandValTileDistribution<BlockGemm, false>());
if constexpr(IsWG32)
static_assert(randval.kThreadElementSpaceSize == 16);
else
static_assert(randval.kThreadElementSpaceSize == 4 ||
randval.kThreadElementSpaceSize == 8);
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
int block_row_start, block_col_start;
if constexpr(IsWG32)
{
block_row_start = (start_m0_idx / WG::kM) + i_m0;
block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
}
else
{
block_row_start = start_m0_idx / 32 + i_m0;
block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
}
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t* random_uint8_t_;
if constexpr(MBwdWG16SingleIterCheck)
{
uint8_t random_uint8_t[4];
// m0t0 ~m0t15/m0t32~m0t47: 0
// m0t16~m0t31/m0t48~m0t63: 1
// m1t0 ~m1t15/m1t32~m1t47: 2
// m1t16~m1t31/m1t48~m1t63: 3
const index_t start_idx =
((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
ph.get_random_4x8(
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
random_uint8_t_ = random_uint8_t;
}
else if constexpr(MBwdWG16MultiIterCheck)
{
uint8_t random_uint8_t[8];
// t0 ~t15/t32~t47: 0
// t16~t31/t48~t63: 1
const index_t start_idx = (get_lane_id() >> 4) & 1;
ph.get_random_8x8(
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
random_uint8_t_ = random_uint8_t;
}
else
{
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t,
reinterpret_cast<unsigned long long&>(rowcol));
random_uint8_t_ = random_uint8_t;
}
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
randval(r_idx) = random_uint8_t[i_random_idx++];
constexpr auto p_idx0 =
tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
randval(r_idx) = random_uint8_t_[i_random_idx++];
constexpr auto p_idx0 = tile_distributed_index<i_m0 + idx0.impl_.at(0),
idx0.impl_.at(1),
idx0.impl_.at(2)>{};
constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
@@ -337,19 +699,19 @@ struct BlockDropout
});
});
// save to Global
if(is_store_randval)
if constexpr(IsStoreRandval)
{
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {kMPerStep, 0});
}
});
if(is_store_randval)
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
}
});
if(is_store_randval)
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
}
@@ -358,7 +720,6 @@ struct BlockDropout
ck_tile::philox ph;
const float rp_undrop;
const uint8_t p_undrop_in_uint8_t;
const bool is_store_randval;
};
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -1,54 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Maps the launch grid to work items: grid.x indexes kN0-sized tiles along
// seqlen_k, grid.y indexes heads and grid.z indexes batches.
template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;

    // Grid extent: (ceil(seqlen_k_ / kN0), nhead_, batch_size_).
    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
    {
        // TODO: this may need tuning
        const auto num_k_tiles = ck_tile::integer_divide_ceil(seqlen_k_, kN0);
        return dim3(num_k_tiles, nhead_, batch_size_);
    }

    // Returns (tile index, head index, batch index) of the calling block,
    // read from blockIdx.x / .y / .z. The argument is unused.
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
    {
        const index_t tile_idx  = static_cast<index_t>(blockIdx.x);
        const index_t head_idx  = static_cast<index_t>(blockIdx.y);
        const index_t batch_idx = static_cast<index_t>(blockIdx.z);

        return ck_tile::make_tuple(tile_idx, head_idx, batch_idx);
    }
};
// Maps the launch grid to work items: grid.x indexes kBlockSize-sized chunks
// along seqlen_q, grid.y indexes heads and grid.z indexes batches.
template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
    // Grid extent: (ceil(seqlen_q_ / kBlockSize), nhead_, batch_size_).
    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
    {
        // TODO: this may need tuning
        const auto num_q_chunks = ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize);
        return dim3(num_q_chunks, nhead_, batch_size_);
    }

    // Returns (chunk index, head index, batch index) of the calling block,
    // read from blockIdx.x / .y / .z. The argument is unused.
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
    {
        const index_t chunk_idx = static_cast<index_t>(blockIdx.x);
        const index_t head_idx  = static_cast<index_t>(blockIdx.y);
        const index_t batch_idx = static_cast<index_t>(blockIdx.z);

        return ck_tile::make_tuple(chunk_idx, head_idx, batch_idx);
    }
};
} // namespace ck_tile

View File

@@ -86,7 +86,7 @@ struct FmhaFwdKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
@@ -387,7 +387,6 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
@@ -448,7 +447,6 @@ struct FmhaFwdKernel
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
@@ -524,7 +522,7 @@ struct FmhaFwdKernel
}
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
batch_offset_lse = query_start;
}
if constexpr(kHasDropout)
{

View File

@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
_SS_(FmhaPipeline::name) +
(pn.empty() ? "" : "_" + pn) +
(kStoreLSE ? "_lse" : "" ) +
(kStoreLSE ? "_lse" : "" ) +
(kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_lse_acc;
};
struct GroupModeKargs
@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
batch_stride_o};
batch_stride_o,
batch_stride_lse_acc};
if constexpr(kStoreLSE)
{
@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc)
{
@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_lse_acc = query_start;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
}
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
}
// for simplicity, batch stride we just modify the pointer

View File

@@ -85,7 +85,7 @@ struct FmhaFwdSplitKVKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse_acc;
};
struct GroupModeKargs
@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for dropout
batch_stride_q,
batch_stride_k,
batch_stride_v};
batch_stride_v,
batch_stride_lse_acc};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc,
@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
long_index_t batch_offset_lse_acc = 0;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
@@ -522,8 +519,9 @@ struct FmhaFwdSplitKVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_lse_acc = query_start;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
@@ -564,9 +562,10 @@ struct FmhaFwdSplitKVKernel
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;

View File

@@ -0,0 +1,141 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// Block-level post-processing step that turns the dQ accumulator (AccDataType)
// into the final dQ tensor (QGradDataType).
//
// Two call forms are provided:
//   operator()(dq_acc, dq)          - convert only
//   operator()(dq_acc, dq, nsplits) - sum `nsplits` partial accumulators stacked
//                                     along the leading window dimension, then convert
// The step uses no shared memory (GetSmemSize() == 0).
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdConvertQGrad
{
    using AccDataType   = remove_cvref_t<typename Problem::AccDataType>;
    using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;

    static constexpr index_t kM0         = Problem::kM0;
    static constexpr index_t kN0         = Problem::kN0;
    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;
    static constexpr index_t kQKHeaddim  = Problem::kQKHeaddim;

    static constexpr bool kIsGroupMode     = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ      = Problem::kPadSeqLenQ;
    static constexpr bool kPadHeadDimQ     = Problem::kPadHeadDimQ;
    static constexpr bool kIsDeterministic = Problem::kIsDeterministic;

    // With head-dim padding the alignment falls back to 1 element.
    static constexpr index_t kAlignmentQGradAcc =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
    static constexpr index_t kAlignmentQGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    // Convert only
    //
    // Loads one accumulator tile through `dq_acc_dram_block_window_tmp`, casts it to
    // QGradDataType and stores it through `dq_dram_block_window_tmp`.
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void
    operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
               QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
    {
        static_assert(
            std::is_same_v<AccDataType,
                           remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");

        auto dq_acc_dram_window =
            make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_acc_dram_block_window_tmp.get_window_lengths(),
                             dq_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePostQGradDramTileDistribution<Problem>());

        auto dq_acc = load_tile(dq_acc_dram_window);

        const auto dq = cast_tile<QGradDataType>(dq_acc);

        store_tile(dq_dram_block_window_tmp, dq);
    }

    // Reduce + Convert
    //
    // Sums `nsplits` accumulator tiles, stepping the accumulator window by one along
    // its leading dimension after each load, then converts the sum to QGradDataType
    // and stores it through `dq_dram_block_window_tmp`.
    //
    // `nsplits` is expected to be >= 1. Exactly `nsplits` tiles are loaded: the loop
    // is pre-tested so that nsplits == 1 does not load (and accumulate) a second
    // split that does not exist.
    template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE void
    operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
               QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
               index_t nsplits) const
    {
        static_assert(
            std::is_same_v<AccDataType,
                           remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
                std::is_same_v<QGradDataType,
                               remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");

        auto dq_acc_dram_window =
            make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                             dq_acc_dram_block_window_tmp.get_window_lengths(),
                             dq_acc_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakePostQGradAccDramTileDistribution<Problem>());

        // Running sum, same tile type as a loaded accumulator tile, zero-initialized.
        auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
        clear_tile(dq_acc);

        constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();
        index_t i_total_loops       = 0;

        // Prefetch split 0; inside the loop the next split is loaded right after the
        // current one has been accumulated.
        auto dq_acc_buf = load_tile(dq_acc_dram_window);
        move_tile_window(dq_acc_dram_window, {1, 0, 0});

        while(i_total_loops < (nsplits - 1))
        {
            sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                    sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                        constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                        dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                    });
                });
            });

            dq_acc_buf = load_tile(dq_acc_dram_window);
            move_tile_window(dq_acc_dram_window, {1, 0, 0});

            i_total_loops += 1;
        }

        // Tail: accumulate the last split that was loaded.
        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
                });
            });
        });

        // declare dq
        constexpr auto dq_converted_dstr =
            Policy::template MakePostQGradAccDramTileDistribution<Problem>();
        auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);

        sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
            sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
                sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
                    constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
                    dq_converted(n_i_j_idx)  = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
                });
            });
        });

        // Re-label the converted values with the output distribution by copying the
        // per-thread buffer as-is into a tensor of that distribution.
        constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
        auto dq                = make_static_distributed_tensor<QGradDataType>(dq_dstr);
        dq.get_thread_buffer() = dq_converted.get_thread_buffer();

        store_tile(dq_dram_block_window_tmp, dq);
    }
};
} // namespace ck_tile

View File

@@ -4,11 +4,11 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy>
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdOGradDotO
{
using ODataType = remove_cvref_t<typename Problem::ODataType>;
@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

View File

@@ -1,20 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// Default policy alias: BlockFmhaBwdPipelineDefaultPolicy instantiated with every
// *LoadOnce_ flag set to false. Per the original note, these template arguments
// are not used here.
using BlockFmhaBwdOGradDotODefaultPolicy = BlockFmhaBwdPipelineDefaultPolicy<
    false, // QLoadOnce_
    false, // QTLoadOnce_
    false, // KLoadOnce_
    false, // KTLoadOnce_
    false, // VLoadOnce_
    false, // OGradLoadOnce_
    false  // OGradTLoadOnce_
    >;
} // namespace ck_tile

View File

@@ -0,0 +1,782 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// Block-level FMHA backward pipeline (name: "kr_ktr_vr").
//
// For the K/V block windows passed in, operator() stages K, K^T and V through LDS into
// register tensors once, then loops over Q tiles of kM0 rows. Each iteration runs five
// block GEMMs (gemm_0..gemm_4), accumulates into dk_acc / dv_acc, and writes the
// iteration's dQ tile out through dq_dram_window. The accumulated (dk_acc, dv_acc) pair
// is returned as a tuple.
//
// Template parameters:
//   Problem - supplies data types, tile shape, padding/bias/dropout/determinism flags.
//   Policy  - supplies tile distributions, LDS descriptors, LDS sizes and the block GEMMs.
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
{
// Element types, all taken from Problem with cv/ref qualifiers removed.
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// Tile extents. operator() statically asserts kQKHeaddim == kK0, kM0 == kK1,
// kVHeaddim == kK2 and kM0 == kK3; kK4 is the slice width used by the gemm_4 loop.
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
// last dimension vector length used to create tensor view (and decide buffer_load vector length)
// ... together with tensor distribution. The tensor distribution should be able to overwrite this.
// Each alignment falls back to 1 when the corresponding dimension is padded.
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
// QGrad alignment is fixed at 1 regardless of padding.
static constexpr index_t kAlignmentQGrad = 1;
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr const char* name = "kr_ktr_vr";
// Forwards to the policy's GetSmemSize for this Problem.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
// Runs the backward pipeline for one K/V block.
//
// Parameters:
//   q/k/v/bias/randval/do/lse/d/dq/dbias_dram_block_window_tmp
//       - tile windows; their tensor views and lengths are reused to build the
//         distributed windows below. Window lengths are checked by static_assert.
//   mask              - provides GetTileRangeAlongY, IsEdgeTile and IsOutOfBound.
//   position_encoding - its update() is applied to each score in the ALIBI branch.
//   raw_scale         - multiplies dq_acc and dk_acc when dropout is disabled.
//   scale             - multiplies the gemm_0 scores before exp2.
//                       NOTE(review): bias is multiplied by log2e and the result fed to
//                       exp2, so scale presumably already includes log2e - confirm at caller.
//   rp_undrop         - multiplies dbias and dv_acc when dropout is enabled.
//   scale_rp_undrop   - multiplies dq_acc and dk_acc when dropout is enabled.
//   smem_ptr          - base pointer from which all LDS views are carved by byte offsets.
//   dropout           - builds the randval window and is run on p when dropout is enabled.
//
// Returns: make_tuple(dk_acc, dv_acc).
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp,
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
FmhaMask mask,
PositionEncoding position_encoding,
float raw_scale,
float scale,
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
FmhaDropout& dropout) const
{
// The element types of the input windows must match the Problem's data types.
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
"wrong!");
// The window lengths must match the tile shape: kM0 rows for the Q-side windows,
// kN0 rows for K/V, and kM0 x kN0 for bias and bias-grad.
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// Block GEMM
// gemm_0: Q@K (STAGE 1), gemm_1: P^T@OGrad^T into dv_acc (STAGE 3),
// gemm_2: OGrad@V (STAGE 4), gemm_3: SGrad^T@Q^T into dk_acc (STAGE 6),
// gemm_4: SGrad@K^T into dq_acc (STAGE 7).
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
// init VGrad & KGrad
// Both accumulators are `{}`-initialized here; clear_tile() is only called on them
// later, just before the hot loop.
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
// K, HBM ->LDS ->Reg
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>());
const auto k_origin = k_dram_window.get_window_origin();
// Early termination
// The mask reports the range of Q rows to visit for this K block; the hot loop
// covers it in steps of kM0 rows.
const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
{
if(num_total_loop <= 0)
{
// Note: here dk_acc&dv_acc are expected to be all cleared, return them.
// NOTE(review): clear_tile() has not run yet on this path, so this relies on the
// `{}` initialization above having zeroed both tiles - confirm.
// Note: v loaded but no fence, ignore it.
// NOTE(review): in this function the K/V load_tile calls appear after this
// return, so the note about V looks stale - confirm.
return make_tuple(dk_acc, dv_acc);
}
}
// K is staged at the LDS base (smem_ptr).
KDataType* k_lds_ptr =
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// V, HBM ->LDS ->Reg
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVDramTileDistribution<Problem>());
// v_lds_ptr points at the same LDS base as k_lds_ptr; V is stored to LDS below only
// after k_reg_tensor has been loaded and a block_sync_lds() issued.
VDataType* v_lds_ptr =
static_cast<VDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto v_lds = make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
Policy::template MakeVRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
Policy::template MakeShuffledKRegWriteBlockDescriptor<Problem>());
// K^T is staged right after the K region (offset GetSmemSizeK from the base). The same
// pointer is viewed through a write descriptor and a transposed read descriptor.
KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto shuffled_k_lds_write = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
auto kt_lds_read_window =
make_tile_window(kt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
{0, 0},
Policy::template MakeKTRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// Pre-Load KV into Registers
// Order matters: K and shuffled K go to LDS, are read back as k_reg_tensor and
// kt_reg_tensor, and only then is V written to LDS (it shares the base with K).
auto k_block_tile = load_tile(k_dram_window);
auto v_block_tile = load_tile(v_dram_window);
store_tile(k_lds_write_window, k_block_tile);
shuffle_tile(shuffled_k_block_tile, k_block_tile);
store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile);
block_sync_lds();
k_reg_tensor = load_tile(k_lds_read_window);
block_sync_lds();
auto kt_reg_tensor = load_tile(kt_lds_read_window);
store_tile(v_lds_write_window, v_block_tile);
block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window);
block_sync_lds();
//---------------------------- Loop Load in ----------------------------//
// LDS layout used inside the hot loop, as byte offsets from smem_ptr in this order:
// QT | OGrad | OGradT | Q | LSE | D | (SGrad and Bias, sharing one region).
// Q: HBM ->Reg ->LDS
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0},
Policy::template MakeQDramTileDistribution<Problem>());
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK0>{}),
q_lds_window.get_window_origin(),
Policy::template MakeQRegSliceBlockDescriptor<Problem>());
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
Policy::template MakePTRegSliceBlockDescriptor<Problem>());
// QT: Reg -> Reg-> LDS
auto shuffled_q_block_tile = make_static_distributed_tensor<QDataType>(
Policy::template MakeShuffledQRegWriteBlockDescriptor<Problem>());
// qt_lds_ptr starts at smem_ptr, i.e. the same base that was used for K/V staging above.
QDataType* qt_lds_ptr =
static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto shuffled_q_lds_write = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
auto qt_lds_read_window =
make_tile_window(qt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
{0, 0},
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
// dO: HBM ->Reg ->LDS
auto do_dram_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0},
Policy::template MakeOGradDramTileDistribution<Problem>());
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK2>{}),
do_lds_window.get_window_origin(),
Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
// dOT: Reg ->Reg ->LDS
auto shuffled_do_block_tile = make_static_distributed_tensor<OGradDataType>(
Policy::template MakeShuffledOGradRegWriteBlockDescriptor<Problem>());
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>()));
auto shuffled_do_lds_write = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto shuffled_do_lds_write_window = make_tile_window(
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
auto dot_lds_read_window =
make_tile_window(dot_read_lds,
make_tuple(number<kVHeaddim>{}, number<kM0>{}),
{0, 0},
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
// dS: Reg -> Reg -> LDS
// ds_lds_ptr is computed with the same offset expression as bias_lds_ptr below, so the
// dS tile and the bias/dbias tile occupy the same LDS region.
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeD<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// The read window is only kK4 columns wide; it is stepped across the kN0 columns
// inside the gemm_4 loop.
auto ds_lds_read_window =
make_tile_window(ds_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK4>{}),
ds_lds_window.get_window_origin(),
Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
// Bias: HBM ->Reg ->Reg ->LDS
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})},
Policy::template MakeBiasTileDistribution<Problem>());
BiasDataType* bias_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeD<Problem>()));
auto bias_lds = make_tensor_view<address_space_enum::lds>(
bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
auto bias_lds_write_window =
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto bias_s_lds_read_window =
make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
bias_lds_write_window.get_window_lengths(),
bias_lds_write_window.get_window_origin(),
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
// The same LDS view (bias_lds) is reused for writing dbias, hence the type requirement.
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window = make_tile_window(
lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
LSEDataType* lse_lds_ptr = static_cast<LSEDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>()));
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// D: HBM ->Reg
// NOTE(review): in the hot loop D is also staged through d_lds_write_window and read
// back via d_lds_read_window before use.
auto d_dram_window = make_tile_window(
d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
DDataType* d_lds_ptr = static_cast<DDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>()));
auto d_lds = make_tensor_view<address_space_enum::lds>(
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start);
// BiasGrad
// Reg ->LDS ->Reg ->HBM
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
auto dbias_dram_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto dbias_lds_read_window =
make_tile_window(bias_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
// ----------------------------Loop write out------------------------------//
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
using SPBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// i_total_loops counts iterations; seqlen_q_step tracks the first Q row of the
// current tile and is used for ALiBi, masking and dropout.
index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3");
// Number of kK4-wide slices of dS consumed by gemm_4 per iteration.
constexpr index_t k4_loops = kN0 / kK4;
clear_tile(dv_acc);
clear_tile(dk_acc);
__builtin_amdgcn_sched_barrier(0);
// Hot loop
while(i_total_loops < num_total_loop)
{
// Load the next Q and LSE tiles and advance their windows by kM0 rows.
auto q_block_tile = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
auto lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
// Stage Q, shuffled Q (read back later through qt_lds_read_window) and LSE through LDS.
store_tile(q_lds_window, q_block_tile);
shuffle_tile(shuffled_q_block_tile, q_block_tile);
store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
store_tile(lse_lds_write_window, lse_block_tile);
block_sync_lds();
auto q_reg_tensor = load_tile(q_lds_read_window);
auto lse = load_tile(lse_lds_read_window);
block_sync_lds();
// STAGE 1, Q@K Gemm0
auto s_acc = SPBlockTileType{};
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
// Bias is shuffled, written to LDS and read back in a distribution derived from
// gemm_0 so it can be combined element-wise with s_acc.
const auto bias_tile = load_tile(bias_dram_window);
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(shuffled_bias_tile, bias_tile);
store_tile(bias_lds_write_window, shuffled_bias_tile);
block_sync_lds();
auto bias_s_tile = load_tile(bias_s_lds_read_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
},
s_acc,
bias_s_tile);
move_tile_window(bias_dram_window, {kM0, 0});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
// Scale each score, then let position_encoding adjust it using its (row, col) position.
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
// Only tiles the mask flags as edge tiles get the per-element out-of-bound check;
// out-of-bound scores are set to -infinity.
bool need_perpixel_check = mask.IsEdgeTile(
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
// Maps an LSE of -infinity to 0 when elementwise bias or masking is enabled; otherwise
// returns the value unchanged.
// NOTE(review): presumably this avoids (-inf) - (-inf) in the exp2 argument below for
// fully masked rows - confirm.
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_lse == -numeric<LSEDataType>::infinity()
? type_convert<LSEDataType>(0.f)
: raw_lse;
}
else
{
return raw_lse;
}
};
// Recompute P = exp2(S - log2e * lse). In the bias branches s_acc was already scaled
// above; otherwise the scale is applied here.
auto p = SPBlockTileType{};
constexpr auto p_spans = decltype(p)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
}
else
{
p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
}
});
});
if constexpr(FmhaDropout::IsDropout)
{
dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
}
// p_gemm is P converted to GemmDataType. With dropout, non-positive entries are
// replaced by 0 before conversion.
// NOTE(review): presumably dropout.Run marks dropped entries by making them negative -
// confirm in block_dropout.hpp.
const auto p_gemm = [&]() {
if constexpr(FmhaDropout::IsDropout)
{
return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
p);
}
else
{
return cast_tile<GemmDataType>(p);
}
}();
// STAGE 3, P^T@OGrad^T Gemm1
auto do_block_tile = load_tile(do_dram_window);
move_tile_window(do_dram_window, {kM0, 0});
auto d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
// Stage dO, shuffled dO and D through LDS.
store_tile(do_lds_window, do_block_tile);
shuffle_tile(shuffled_do_block_tile, do_block_tile);
store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
store_tile(d_lds_write_window, d_block_tile);
block_sync_lds();
auto dot_reg_tensor = load_tile(dot_lds_read_window);
block_sync_lds();
// Convert the gemm_0 output tile into the A operand of gemm_1.
Policy::template PTFromGemm0CToGemm1A<Problem,
decltype(pt_reg_tensor),
decltype(p_gemm)>(pt_reg_tensor, p_gemm);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
// STAGE 4, OGrad@V Gemm2
auto do_reg_tensor = load_tile(do_lds_read_window);
auto d = load_tile(d_lds_read_window);
block_sync_lds();
auto dp_acc = SPGradBlockTileType{};
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
// STAGE 5, P^T(PGrad^T - D)
// ds = p * (dp - d), except that with dropout enabled and p < 0 the factor is d
// instead of (dp - d).
auto ds = SPGradBlockTileType{};
constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = p[i_j_idx] >= 0;
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
? (dp_acc[i_j_idx] - d[i_idx])
: d[i_idx]);
});
});
if constexpr(kHasBiasGrad)
{
// dbias is ds converted to BiasGradDataType (multiplied by rp_undrop first when
// dropout is enabled). It is written to LDS, read back in the shuffled distribution,
// shuffled into the bias distribution and stored to the dbias window.
const auto dbias = [&]() {
if constexpr(FmhaDropout::IsDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
ds);
}
else
{
return cast_tile<BiasGradDataType>(ds);
}
}();
store_tile(bias_lds_write_window, dbias);
block_sync_lds();
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbias_tile, shuffled_dbias_tile);
store_tile(dbias_dram_window, dbias_tile);
move_tile_window(dbias_dram_window, {kM0, 0});
__builtin_amdgcn_sched_barrier(0);
}
// STAGE 6, SGrad^T@Q^T Gemm3
auto qt_reg_tensor = load_tile(qt_lds_read_window);
block_sync_lds();
const auto ds_gemm = cast_tile<GemmDataType>(ds);
// Convert the gemm_2-shaped dS tile into the A operand of gemm_3.
Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor),
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
// dS also goes to LDS so gemm_4 can read it back in kK4-wide slices.
store_tile(ds_lds_window, ds_gemm);
block_sync_lds();
auto ds_reg_tensor = load_tile(ds_lds_read_window);
auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
move_tile_window(ds_lds_read_window, {0, kK4});
// STAGE7 SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc);
// For every slice but the last, the next dS slice is loaded before gemm_4 consumes
// the current one, then swapped in afterwards.
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile(ds_lds_read_window);
move_tile_window(ds_lds_read_window, {0, kK4});
}
auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
}
});
// The read window advanced by k4_loops * kK4 == kN0 columns in total; rewind it for
// the next iteration.
move_tile_window(ds_lds_read_window, {0, -kN0});
// QGrad Scale
if constexpr(FmhaDropout::IsDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
// Deterministic mode uses store_tile; otherwise update_tile is used.
// NOTE(review): update_tile presumably accumulates into the destination - confirm.
if constexpr(kIsDeterministic)
{
store_tile(dq_dram_window, dq_acc);
}
else
{
update_tile(dq_dram_window, dq_acc);
}
move_tile_window(dq_dram_window, {kM0, 0});
i_total_loops += 1;
seqlen_q_step += kM0;
}
// Results Scale
// With dropout: dk_acc *= scale_rp_undrop and dv_acc *= rp_undrop.
// Without dropout: only dk_acc is rescaled (by raw_scale); dv_acc is returned as accumulated.
if constexpr(FmhaDropout::IsDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc);
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
return make_tuple(dk_acc, dv_acc);
}
};
} // namespace ck_tile

View File

@@ -1,848 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kQLoadOnce = false;
static constexpr bool kQTLoadOnce = false;
static constexpr bool kKLoadOnce = true;
static constexpr bool kKTLoadOnce = true;
static constexpr bool kVLoadOnce = true;
static constexpr bool kOGradLoadOnce = false;
static constexpr bool kOGradTLoadOnce = false;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view (and decide buffer_load vector length)
// ... together with tensor distribution. The tensor distribution should be able to overwrite this.
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr const char* name = "ks_kts_vr";
// Shared-memory size this pipeline needs, exactly as reported by the policy for Problem.
// NOTE(review): the unit is whatever Policy::GetSmemSize returns -- presumably bytes; confirm.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    const ck_tile::index_t smem_size = Policy::template GetSmemSize<Problem>();
    return smem_size;
}
// Backward pipeline body for one K/V tile of the "ks_kts_vr" variant.
//
// K and K^T are written to LDS once and V is loaded into registers once. The main loop then walks
// the Q rows reported by mask.GetTileRangeAlongY() in steps of kM0 and, per step, runs STAGE 1-7:
// gemm_0 produces S^T, P^T is rebuilt from S^T and LSE, gemm_1 accumulates into dv_acc, gemm_2
// produces PGrad^T, SGrad^T is formed from P^T, PGrad^T and D, gemm_3 accumulates into dk_acc and
// gemm_4 produces the QGrad tile that is pushed to DRAM with update_tile(). When kHasBiasGrad is
// set, the bias gradient tile is also stored to DRAM in every step.
//
// LDS layout relative to smem_ptr: K at offset 0, K^T at GetSmemSizeK, and one scratch region at
// GetSmemSizeK + GetSmemSizeKT that is shared by the Q, Q^T, OGrad, OGrad^T, SGrad and
// Bias^T/BiasGrad^T views. Because those views alias each other, every store into the scratch
// region must be separated from earlier reads of it by block_sync_lds().
//
// Parameters:
//   *_dram_block_window_tmp  DRAM tile windows. Only their tensor views, window lengths and (for
//                            K, K^T, V and the bias column) origins are used; the Q-side windows
//                            are re-created at the first Q row of the loop range.
//   mask                     provides the Q row range, the edge-tile test and the per-element
//                            out-of-bound test.
//   position_encoding        only used when BiasEnum is ALIBI.
//   raw_scale                scales S^T when CK_TILE_FMHA_FWD_FAST_EXP2 is off, and scales QGrad
//                            and KGrad when kHasDropout is false.
//   scale                    (CK_TILE_FMHA_FWD_FAST_EXP2 only) scale applied to S^T for exp2.
//   rp_undrop                multiplies BiasGrad and VGrad when kHasDropout is true
//                            (presumably 1 / (1 - p_drop) -- confirm with the caller).
//   scale_rp_undrop          multiplies QGrad and KGrad when kHasDropout is true.
//   smem_ptr                 base of the LDS buffer described above.
//   dropout                  BlockDropout object used to build the randval window and to apply
//                            dropout to P^T.
//
// Returns ck_tile::make_tuple(dk_acc, dv_acc). On the masked early exit both accumulators are
// returned cleared.
template <typename QDramBlockWindowTmp,
          typename QTDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename KTDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
          typename BiasDramBlockWindowTmp,
          typename RandValDramBlockWindowTmp,
          typename OGradDramBlockWindowTmp,
          typename OGradTDramBlockWindowTmp,
          typename LSEDramBlockWindowTmp,
          typename DDramBlockWindowTmp,
          typename QGradDramBlockWindowTmp,
          typename BiasGradDramBlockWindowTmp,
          typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
           const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
           const KDramBlockWindowTmp& k_dram_block_window_tmp,
           const KTDramBlockWindowTmp& kt_dram_block_window_tmp,
           const VDramBlockWindowTmp& v_dram_block_window_tmp,
           const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
           const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
           const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
           const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
           const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
           const DDramBlockWindowTmp& d_dram_block_window_tmp,
           const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
           const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
           FmhaMask mask,
           PositionEncoding position_encoding,
           float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
           float scale,
#endif
           float rp_undrop,
           float scale_rp_undrop,
           void* smem_ptr,
           BlockDropout& dropout) const
{
    // Element types of the incoming windows must match the types declared by Problem.
    static_assert(
        std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
            std::is_same_v<QDataType,
                           remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
            std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
            std::is_same_v<KDataType,
                           remove_cvref_t<typename KTDramBlockWindowTmp::DataType>> &&
            std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
            std::is_same_v<OGradDataType,
                           remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
            std::is_same_v<OGradDataType,
                           remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
            std::is_same_v<LSEDataType,
                           remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
            std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
            std::is_same_v<QGradDataType,
                           remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
        "wrong!");
    // Leading window lengths must match the tile shape this pipeline was built for.
    static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                      kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kVHeaddim ==
                          OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                  "wrong!");

    // Q tile in LDS (scratch region: smem_ptr + SmemSizeK + SmemSizeKT)
    QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
        static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
        Policy::template GetSmemSizeKT<Problem>()));
    auto q_lds = make_tensor_view<address_space_enum::lds>(
        q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
    auto q_lds_window =
        make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});

    // QT tile in LDS (same scratch region as Q)
    QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
        static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
        Policy::template GetSmemSizeKT<Problem>()));
    auto qt_lds = make_tensor_view<address_space_enum::lds>(
        qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
    auto qt_lds_window =
        make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});

    // K tile in LDS (offset 0)
    auto k_lds = make_tensor_view<address_space_enum::lds>(
        reinterpret_cast<KDataType*>(smem_ptr),
        Policy::template MakeKLdsBlockDescriptor<Problem>());
    auto k_lds_window =
        make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});

    // KT tile in LDS (offset SmemSizeK)
    KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
        static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
    auto kt_lds = make_tensor_view<address_space_enum::lds>(
        kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor<Problem>());
    auto kt_lds_window =
        make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});

    // OGrad tile in LDS (same scratch region as Q)
    OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
        static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
        Policy::template GetSmemSizeKT<Problem>()));
    auto do_lds = make_tensor_view<address_space_enum::lds>(
        do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
    auto do_lds_window =
        make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});

    // OGradT tile in LDS (same scratch region as Q)
    OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
        static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
        Policy::template GetSmemSizeKT<Problem>()));
    auto dot_lds = make_tensor_view<address_space_enum::lds>(
        dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
    auto dot_lds_window =
        make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});

    // SGrad tile in LDS (same scratch region as Q)
    GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
        static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
        Policy::template GetSmemSizeKT<Problem>()));
    auto ds_lds = make_tensor_view<address_space_enum::lds>(
        ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
    auto ds_lds_window =
        make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

    // BiasT/BiasGradT tile in LDS, use the same size and layout (same scratch region as Q)
    BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
        static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
        Policy::template GetSmemSizeKT<Problem>()));
    auto biast_lds = make_tensor_view<address_space_enum::lds>(
        biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
    auto biast_lds_shuffle_window =
        make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
    // Same LDS view, but with the shuffled-bias distribution attached so it can be load_tile()'d.
    auto dbiast_lds_shuffle_window =
        make_tile_window(biast_lds,
                         make_tuple(number<kM0>{}, number<kN0>{}),
                         {0, 0},
                         Policy::template MakeShuffledBiasTileDistribution<Problem>());

    static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
                  "BiasDataType and BiasGradDataType should be the same!");

    // Block GEMM
    constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
    constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
    constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
    constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
    constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();

    auto v_dram_window = make_tile_window(
        v_dram_block_window_tmp.get_bottom_tensor_view(),
        v_dram_block_window_tmp.get_window_lengths(),
        v_dram_block_window_tmp.get_window_origin(),
        Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());

    auto v = load_tile(v_dram_window); // persistent V register tile

    using SPTBlockTileType     = decltype(gemm_0.MakeCBlockTile());
    using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
    using QGradBlockTileType   = decltype(gemm_4.MakeCBlockTile());

    // init VGrad & KGrad
    auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
    auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};

    clear_tile(dv_acc);
    clear_tile(dk_acc);

    auto k_dram_window = make_tile_window(
        k_dram_block_window_tmp.get_bottom_tensor_view(),
        k_dram_block_window_tmp.get_window_lengths(),
        k_dram_block_window_tmp.get_window_origin(),
        Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
                                                                // load

    __builtin_amdgcn_sched_barrier(0);

    // Q row range [seqlen_q_start, seqlen_q_end) that this K tile has to visit.
    const auto k_origin = k_dram_window.get_window_origin();
    const auto [seqlen_q_start, seqlen_q_end] =
        mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

    const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);

    // check early exit if masked and no work to do.
    if constexpr(FmhaMask::IsMasking)
    {
        if(num_total_loop <= 0)
        {
            // Note: here dk_acc & dv_acc are all cleared, return them
            // Note: v loaded but no fence, ignore it.
            return ck_tile::make_tuple(dk_acc, dv_acc);
        }
    }

    auto k_block_tile = load_tile(k_dram_window);

    store_tile(k_lds_window, k_block_tile); // persistent K in LDS

    auto kt_dram_block_window = kt_dram_block_window_tmp;

    auto kt_dram_window = make_tile_window(
        kt_dram_block_window.get_bottom_tensor_view(),
        kt_dram_block_window.get_window_lengths(),
        kt_dram_block_window.get_window_origin(),
        Policy::template MakeKTDramTileDistribution<Problem>()); // K^T DRAM tile window for
                                                                 // load

    auto kt_block_tile = load_tile(kt_dram_window);

    auto kt_shuffle_tmp = make_static_distributed_tensor<KDataType>(
        Policy::template MakeShuffledKTRegBlockDescriptor<Problem>());
    shuffle_tile(kt_shuffle_tmp, kt_block_tile);

    store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS

    // Re-create every Q-side window so that it starts at seqlen_q_start. For the transposed
    // windows (Q^T, OGrad^T) the Q rows are the second coordinate.
    auto q_dram_block_window =
        make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                         q_dram_block_window_tmp.get_window_lengths(),
                         {seqlen_q_start, 0});

    auto qt_dram_block_window =
        make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
                         qt_dram_block_window_tmp.get_window_lengths(),
                         {0, seqlen_q_start});

    auto do_dram_block_window =
        make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
                         do_dram_block_window_tmp.get_window_lengths(),
                         {seqlen_q_start, 0});

    auto dot_dram_block_window =
        make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
                         dot_dram_block_window_tmp.get_window_lengths(),
                         {0, seqlen_q_start});

    auto dq_dram_block_window =
        make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
                         dq_dram_block_window_tmp.get_window_lengths(),
                         {seqlen_q_start, 0});

    auto lse_dram_block_window =
        make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
                         lse_dram_block_window_tmp.get_window_lengths(),
                         {seqlen_q_start});

    auto d_dram_block_window =
        make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
                         d_dram_block_window_tmp.get_window_lengths(),
                         {seqlen_q_start});

    // Bias / BiasGrad keep the caller's column origin and only have their row re-based.
    const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
    auto bias_dram_block_window =
        make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                         bias_dram_block_window_tmp.get_window_lengths(),
                         {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N

    const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
    auto dbias_dram_block_window =
        make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
                         dbias_dram_block_window_tmp.get_window_lengths(),
                         {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N

    // Windows below carry a tile distribution and live across loop iterations; they are advanced
    // inside the loop.
    auto qt_dram_window =
        make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
                         qt_dram_block_window.get_window_lengths(),
                         qt_dram_block_window.get_window_origin(),
                         Policy::template MakeQTDramTileDistribution<Problem>());

    auto dot_dram_window =
        make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
                         dot_dram_block_window.get_window_lengths(),
                         dot_dram_block_window.get_window_origin(),
                         Policy::template MakeOGradTDramTileDistribution<Problem>());

    auto lse_dram_window = make_tile_window(
        lse_dram_block_window.get_bottom_tensor_view(),
        lse_dram_block_window.get_window_lengths(),
        lse_dram_block_window.get_window_origin(),
        Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());

    auto d_dram_window = make_tile_window(
        d_dram_block_window.get_bottom_tensor_view(),
        d_dram_block_window.get_window_lengths(),
        d_dram_block_window.get_window_origin(),
        Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());

    auto bias_dram_window =
        make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
                         bias_dram_block_window.get_window_lengths(),
                         bias_dram_block_window.get_window_origin(),
                         Policy::template MakeBiasTileDistribution<Problem>());

    auto biast_lds_window =
        make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
                         biast_lds_shuffle_window.get_window_lengths(),
                         biast_lds_shuffle_window.get_window_origin(),
                         Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());

    auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
        randval_dram_block_window_tmp, seqlen_q_start);

    index_t i_total_loops = 0;
    // Number of slices each GEMM is split into.
    // NOTE(review): the STAGE 1 and STAGE 4 tails always consume slices (loops - 2) and
    // (loops - 1), so they appear to assume k0_loops >= 2 and k2_loops >= 2 -- confirm.
    constexpr index_t k0_loops = kQKHeaddim / kK0;
    constexpr index_t k1_loops = kM0 / kK1;
    constexpr index_t k2_loops = kVHeaddim / kK2;
    constexpr index_t k3_loops = kM0 / kK3;
    constexpr index_t k4_loops = kN0 / kK4;
    // NOTE(review): do/while runs the body once even when num_total_loop <= 0; the early return
    // above only covers FmhaMask::IsMasking -- confirm callers never pass an empty Q range
    // otherwise.
    do
    {
        auto q_dram_window = make_tile_window(
            q_dram_block_window.get_bottom_tensor_view(),
            q_dram_block_window.get_window_lengths(),
            q_dram_block_window.get_window_origin(),
            Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
                                                                    // load

        auto do_dram_window = make_tile_window(
            do_dram_block_window.get_bottom_tensor_view(),
            do_dram_block_window.get_window_lengths(),
            do_dram_block_window.get_window_origin(),
            Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
                                                                        // window for load

        // STAGE 1, Q@K Gemm0
        auto st_acc = SPTBlockTileType{};

        auto q_block_tile = load_tile(q_dram_window);
        {
            move_tile_window(q_dram_window, {0, kK0});

            clear_tile(st_acc); // Initialize S^T

            // q_lds_window aliases ds_lds_window, which gemm_4 of the previous iteration read
            // without a trailing barrier. Wait for all threads to finish those reads before
            // overwriting the scratch region (redundant but harmless on the first iteration).
            block_sync_lds();
            store_tile(q_lds_window, q_block_tile);  // LDS write 0
            q_block_tile = load_tile(q_dram_window); // global read 1
        }

        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            __builtin_amdgcn_sched_barrier(
                0); // prevent from messing up the order of global loads
        }
        const auto bias_tile = load_tile(bias_dram_window); // load bias tile
        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            __builtin_amdgcn_sched_barrier(
                0); // prevent from messing up the order of global loads
        }

        // Pipelined main part: compute on the slice already in LDS, then write the prefetched
        // slice to LDS and start the next global read.
        if constexpr(k0_loops > 2)
        {
            static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
                block_sync_lds();
                gemm_0(st_acc,
                       q_lds_window,
                       get_slice_tile(k_lds_window,
                                      sequence<0, i_k0 * kK0>{},
                                      sequence<kN0, (i_k0 + 1) * kK0>{}));
                block_sync_lds();
                move_tile_window(q_dram_window, {0, kK0});

                store_tile(q_lds_window,
                           q_block_tile);                // LDS write i + 1
                q_block_tile = load_tile(q_dram_window); // global read i + 2
            });
        }

        const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
        {                                                     // tail
            block_sync_lds();
            gemm_0(st_acc,
                   q_lds_window,
                   get_slice_tile(k_lds_window,
                                  sequence<0, (k0_loops - 2) * kK0>{},
                                  sequence<kN0, (k0_loops - 1) * kK0>{}));
            block_sync_lds();

            store_tile(q_lds_window, q_block_tile);
            block_sync_lds();

            gemm_0(st_acc,
                   q_lds_window,
                   get_slice_tile(k_lds_window,
                                  sequence<0, (k0_loops - 1) * kK0>{},
                                  sequence<kN0, k0_loops * kK0>{}));
        }

        // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            // Round-trip the bias tile through the LDS scratch region to obtain it in the
            // distribution built from gemm_0 (biast_lds_window).
            block_sync_lds();
            auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
                Policy::template MakeShuffledBiasTileDistribution<Problem>());
            shuffle_tile(bias_shuffle_tmp, bias_tile);
            store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
            block_sync_lds();
            auto biast_tile = load_tile(biast_lds_window);
            tile_elementwise_inout(
                [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                    x = raw_scale * x + type_convert<AccDataType>(y);
#else
                    x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
                },
                st_acc,
                biast_tile);
            move_tile_window(bias_dram_window, {kM0, 0});
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            const auto q_origin    = q_dram_block_window.get_window_origin();
            constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
            sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
                    const auto tile_idx = get_x_indices_from_distributed_indices(
                        st_acc.get_tile_distribution(), make_tuple(idx0, idx1));

                    const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                    const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);

#if !CK_TILE_FMHA_FWD_FAST_EXP2
                    st_acc(i_j_idx) *= raw_scale;
#else
                    st_acc(i_j_idx) *= scale;
#endif
                    position_encoding.update(st_acc(i_j_idx), row, col);
                });
            });
        }
        else
        {
            // Without bias the FAST_EXP2 build defers the scale to the exp2() call below.
#if !CK_TILE_FMHA_FWD_FAST_EXP2
            tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
        }

        if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
        {
            const auto q_origin      = q_dram_block_window.get_window_origin();
            bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
                                                       k_origin.at(number<0>{}),
                                                       number<kM0>{},
                                                       number<kN0>{});
            if(need_perpixel_check)
            {
                // Only edge tiles need the per-element test; out-of-bound scores become -inf.
                set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
                    const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                    const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                    return mask.IsOutOfBound(row, col);
                });
            }
        }

        const auto lse = load_tile(lse_dram_window);

        // With elementwise bias or masking, an LSE of -inf is replaced by 0 before it is
        // subtracted from the scores.
        static const auto get_validated_lse = [](LSEDataType raw_lse) {
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                         FmhaMask::IsMasking)
            {
                return raw_lse == -numeric<LSEDataType>::infinity()
                           ? type_convert<LSEDataType>(0.f)
                           : raw_lse;
            }
            else
            {
                return raw_lse;
            }
        };

        // P^T = exp(S^T - LSE) (or the exp2 equivalent in FAST_EXP2 builds).
        auto pt                 = SPTBlockTileType{};
        constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
        sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
            auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
            sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             BiasEnum == BlockAttentionBiasEnum::ALIBI)
                {
                    // scale was already applied together with the bias above
                    pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
                }
                else
                {
                    pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
                }
#else
                pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
            });
        });

        auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
            Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
        block_sync_lds();
        {
            shuffle_tile(dot_shuffle_tmp, dot_prefetch);
            store_tile(dot_lds_window,
                       dot_shuffle_tmp); // store the prefetch
        }
        move_tile_window(dot_dram_window, {0, kK1});

        if constexpr(kHasDropout)
        {
            // NOTE(review): the code below treats pt < 0 as "dropped", so dropout.Run presumably
            // encodes dropped elements by their sign -- confirm in BlockDropout.
            dropout.Run<decltype(gemm_0), RandValOutputDataType>(
                seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
        }

        // STAGE 3, P^T@OGrad^T Gemm1
        // With dropout, non-positive entries of pt are fed to the GEMM as 0.
        const auto pt_gemm = [&]() {
            if constexpr(kHasDropout)
            {
                return tile_elementwise_in(
                    [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
                    pt);
            }
            else
            {
                return cast_tile<GemmDataType>(pt);
            }
        }();

        if constexpr(k1_loops > 1)
        {
            static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
                const auto dot = load_tile(dot_dram_window); // load next OGrad^T
                block_sync_lds();
                gemm_1(dv_acc,
                       get_slice_tile(pt_gemm,
                                      sequence<i_k1 * kK1, 0>{},
                                      sequence<(i_k1 + 1) * kK1, kN0>{}),
                       dot_lds_window);
                block_sync_lds();
                shuffle_tile(dot_shuffle_tmp, dot);
                store_tile(dot_lds_window,
                           dot_shuffle_tmp); // store the prefetch

                move_tile_window(dot_dram_window, {0, kK1});
            });
        }
        auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
        // tail
        {
            block_sync_lds();
            gemm_1(dv_acc,
                   get_slice_tile(
                       pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
                   dot_lds_window);
            block_sync_lds();
        }

        // STAGE 4, OGrad@V Gemm2
        auto dpt_acc = SPGradTBlockTileType{};

        {
            move_tile_window(do_dram_window, {0, kK2});

            clear_tile(dpt_acc); // Initialize PGrad^T

            store_tile(do_lds_window, do_block_tile);  // LDS write 0
            do_block_tile = load_tile(do_dram_window); // global read 1
        }

        if constexpr(k2_loops > 2)
        {
            static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
                block_sync_lds();
                gemm_2(dpt_acc,
                       do_lds_window,
                       get_slice_tile(
                           v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
                block_sync_lds();
                move_tile_window(do_dram_window, {0, kK2});

                store_tile(do_lds_window,
                           do_block_tile);                 // LDS write i + 1
                do_block_tile = load_tile(do_dram_window); // global read i + 2
            });
        }

        const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
        {                                                   // tail
            block_sync_lds();
            gemm_2(dpt_acc,
                   do_lds_window,
                   get_slice_tile(v,
                                  sequence<0, (k2_loops - 2) * kK2>{},
                                  sequence<kN0, (k2_loops - 1) * kK2>{}));
            block_sync_lds();

            store_tile(do_lds_window, do_block_tile);
            block_sync_lds();

            gemm_2(dpt_acc,
                   do_lds_window,
                   get_slice_tile(v,
                                  sequence<0, (k2_loops - 1) * kK2>{},
                                  sequence<kN0, k2_loops * kK2>{}));
        }

        // STAGE 5, P^T(PGrad^T - D)
        const auto d = load_tile(d_dram_window);

        auto dst                 = SPGradTBlockTileType{};
        constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
        sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                // Elements with pt < 0 (only possible after dropout) use pt * d instead of
                // pt * (dpt - d).
                bool undrop_flag = pt[i_j_idx] >= 0;
                dst(i_j_idx) =
                    pt[i_j_idx] *
                    (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
            });
        });

        if constexpr(kHasBiasGrad)
        {
            const auto dbiast = [&]() {
                if constexpr(kHasDropout)
                {
                    return tile_elementwise_in(
                        [&rp_undrop](const auto& x) {
                            return type_convert<BiasGradDataType>(x * rp_undrop);
                        },
                        dst);
                }
                else
                {
                    return cast_tile<BiasGradDataType>(dst);
                }
            }();
            // biast_lds_shuffle_window aliases do_lds_window, which the STAGE 4 tail gemm_2 has
            // just read. Wait for all threads to finish those reads before overwriting the
            // scratch region; every other store into it is already preceded by a barrier.
            block_sync_lds();
            store_tile(biast_lds_shuffle_window, dbiast);
            block_sync_lds();
            auto dbiast_tile        = load_tile(dbiast_lds_shuffle_window);
            auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
                Policy::template MakeBiasTileDistribution<Problem>());
            shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
            store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
            move_tile_window(dbias_dram_block_window, {kM0, 0});
        }

        // STAGE 6, SGrad^T@Q^T Gemm3
        auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
            Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
        block_sync_lds();
        {
            shuffle_tile(qt_shuffle_tmp, qt_prefetch);
            store_tile(qt_lds_window,
                       qt_shuffle_tmp); // store the prefetch
        }
        move_tile_window(qt_dram_window, {0, kK3});

        const auto dst_gemm = cast_tile<GemmDataType>(dst);

        if constexpr(k3_loops > 1)
        {
            static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
                const auto qt = load_tile(qt_dram_window); // load next Q^T
                block_sync_lds();
                gemm_3(dk_acc,
                       get_slice_tile(dst_gemm,
                                      sequence<i_k3 * kK3, 0>{},
                                      sequence<(i_k3 + 1) * kK3, kN0>{}),
                       qt_lds_window);
                block_sync_lds();
                shuffle_tile(qt_shuffle_tmp, qt);
                store_tile(qt_lds_window,
                           qt_shuffle_tmp); // store the prefetch

                move_tile_window(qt_dram_window, {0, kK3});
            });
        }
        // tail
        {
            block_sync_lds();
            gemm_3(dk_acc,
                   get_slice_tile(
                       dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
                   qt_lds_window);
            block_sync_lds();
        }

        // STAGE 7, SGrad@K^T Gemm4
        store_tile(ds_lds_window, dst_gemm);

        auto dq_acc = QGradBlockTileType{};
        clear_tile(dq_acc); // Initialize QGrad

        block_sync_lds();

        static_for<0, k4_loops, 1>{}([&](auto i_k4) {
            gemm_4(dq_acc,
                   get_slice_tile(ds_lds_window,
                                  sequence<0, i_k4 * kK4>{},
                                  sequence<kM0, (i_k4 + 1) * kK4>{}),
                   get_slice_tile(kt_lds_window,
                                  sequence<0, i_k4 * kK4>{},
                                  sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
        });

        // QGrad Scale
        if constexpr(kHasDropout)
        {
            tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                                   dq_acc);
        }
        else
        {
            tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
        }
        const auto dq = cast_tile<QGradDataType>(dq_acc);
        // NOTE(review): update_tile presumably accumulates into DRAM rather than overwriting,
        // since each K tile contributes to the same QGrad rows -- confirm.
        update_tile(dq_dram_block_window, dq);

        // move tile windows
        // (qt_dram_window, dot_dram_window, bias_dram_window and dbias_dram_block_window were
        // already advanced inside the stages above)
        move_tile_window(q_dram_block_window, {kM0, 0});
        move_tile_window(dq_dram_block_window, {kM0, 0});
        move_tile_window(do_dram_block_window, {kM0, 0});
        move_tile_window(lse_dram_window, {kM0});
        move_tile_window(d_dram_window, {kM0});
    } while(++i_total_loops < num_total_loop);

    // KGrad Scale
    if constexpr(kHasDropout)
    {
        tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                               dk_acc);
    }
    else
    {
        tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
    }
    // VGrad Scale
    if constexpr(kHasDropout)
    {
        tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
    }

    return ck_tile::make_tuple(dk_acc, dv_acc);
}
};
} // namespace ck_tile

View File

@@ -1,20 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// Default policy for the backward pipeline that keeps V in registers and K / K^T in LDS.
// It instantiates BlockFmhaBwdPipelineDefaultPolicy with the K, K^T and V "load once" flags set
// and the Q, Q^T, OGrad and OGrad^T "load once" flags cleared.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ true,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile

View File

@@ -1,821 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKSVR
{
// Element types of the inputs, intermediates and gradients, all forwarded from Problem.
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
// Mask type and tile shape forwarded from Problem.
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
// Launch configuration forwarded from Problem.
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// Tile extents taken from BlockFmhaShape.
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
// Fixed "load once" flags of this pipeline variant: only K and V are marked load-once here
// (kKTLoadOnce is false).
// NOTE(review): the name "ks_vr" suggests K staged in LDS and V kept in registers -- confirm
// against the policy and operator().
static constexpr bool kQLoadOnce = false;
static constexpr bool kQTLoadOnce = false;
static constexpr bool kKLoadOnce = true;
static constexpr bool kKTLoadOnce = false;
static constexpr bool kVLoadOnce = true;
static constexpr bool kOGradLoadOnce = false;
static constexpr bool kOGradTLoadOnce = false;
// Mode, padding, bias and dropout switches forwarded from Problem.
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
// Each alignment comes from the policy unless the corresponding dimension is padded, in which
// case a fixed fallback is used (Q/K/KGrad follow kPadHeadDimQ, V/O/OGrad/VGrad follow
// kPadHeadDimV, Bias follows kPadSeqLenK).
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
// NOTE(review): the padded fallback here is 2 while every other alignment falls back to 1 --
// confirm this is intentional.
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
// Short identifier of this pipeline variant.
static constexpr const char* name = "ks_vr";
// Shared-memory size this pipeline needs, exactly as reported by the policy for Problem.
// NOTE(review): the unit is whatever Policy::GetSmemSize returns -- presumably bytes; confirm.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    const ck_tile::index_t smem_size = Policy::template GetSmemSize<Problem>();
    return smem_size;
}
// Backward pass for one K/V tile of the "ks_vr" pipeline.
//
// K is stored to LDS once and V is loaded into a register tile once; the loop below then
// walks the Q-side tensors along seqlen_q in steps of kM0. Each iteration runs the seven
// stages labelled in the body (Gemm0 .. Gemm4 plus the softmax/dropout and dS stages),
// accumulating into dv_acc (gemm_1) and dk_acc (gemm_3) and writing a dQ tile through
// update_tile().
//
// Parameters:
//   *_dram_block_window_tmp  Tile windows over the DRAM tensors. Only their bottom tensor
//                            views, window lengths and (for K, V, bias, dbias) origins are
//                            used; the Q-side windows are rebuilt with an origin derived
//                            from seqlen_q_start. The KT window is unused.
//   mask                     Supplies GetTileRangeAlongY / IsEdgeTile / IsOutOfBound.
//   position_encoding        Only used in the ALIBI branch (position_encoding.update).
//   raw_scale                Scales S when CK_TILE_FMHA_FWD_FAST_EXP2 is off, and scales
//                            dq / dk when kHasDropout is false.
//   scale                    Only present with CK_TILE_FMHA_FWD_FAST_EXP2; scales S on the
//                            exp2 path.
//   rp_undrop                Scales dbias and dv when kHasDropout is true.
//   scale_rp_undrop          Scales dq and dk when kHasDropout is true.
//   smem_ptr                 Base address of the LDS scratch area.
//   dropout                  BlockDropout object; provides the randval window and Run().
//
// Returns: ck_tile::make_tuple(dk_acc, dv_acc).
template <typename QDramBlockWindowTmp,
typename QTDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename KTDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename OGradTDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp,
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
FmhaMask mask,
PositionEncoding position_encoding,
float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale,
#endif
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
BlockDropout& dropout) const
{
// Compile-time check: every window's element type must match the pipeline's data types.
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<QDataType,
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
// Compile-time check: window lengths must match the block shape (kM0 / kN0 / head dims).
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kVHeaddim ==
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// LDS layout: the K / KT views start at smem_ptr. The Q, QT, OGrad, OGradT, SGrad and
// BiasT views below ALL start at the same byte offset (GetSmemSizeK past smem_ptr), i.e.
// they share one LDS region and rely on the block_sync_lds() calls in the loop to keep
// their uses apart.
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
// QT tile in LDS
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto qt_lds = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
auto qt_lds_window =
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
// K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
// KT tile in LDS (same memory as the K tile, viewed through a different descriptor)
auto kt_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
auto kt_lds_window =
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
// OGrad tile in LDS
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
// OGradT tile in LDS
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto dot_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
auto dot_lds_window =
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
// SGrad tile in LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
auto biast_lds_shuffle_window =
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto dbiast_lds_shuffle_window =
make_tile_window(biast_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
// The bias and bias-grad tiles go through the same LDS view, so their types must agree.
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
// Re-create the V window with the distribution gemm_2 expects, then load V once.
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
auto v = load_tile(v_dram_window); // persistent V register tile
// Accumulator tile types are taken from the C tiles of the corresponding GEMMs.
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
clear_tile(dv_acc);
clear_tile(dk_acc);
auto k_dram_window = make_tile_window(
k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier(0);
// Ask the mask which range of seqlen_q this K tile has to visit, and turn that into a
// number of kM0-sized iterations.
const auto k_origin = k_dram_window.get_window_origin();
const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
{
if(num_total_loop <= 0)
{
// Note: here dk_acc & dv_acc are all cleared, return them
// Note: v loaded but no fence, ignore it.
return ck_tile::make_tuple(dk_acc, dv_acc);
}
}
auto k_block_tile = load_tile(k_dram_window);
store_tile(k_lds_window, k_block_tile); // persistent K in LDS
// Rebuild the Q-side block windows so that they start at seqlen_q_start. The transposed
// windows (QT, OGradT) carry seqlen_q in their second coordinate; LSE and D are 1-D.
auto q_dram_block_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto qt_dram_block_window =
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
qt_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_q_start});
auto do_dram_block_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto dot_dram_block_window =
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
dot_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_q_start});
auto dq_dram_block_window =
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto lse_dram_block_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
auto d_dram_block_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
// Bias / dBias keep the caller's second coordinate and only replace the first one.
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_block_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
auto dbias_dram_block_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
// Windows with tile distributions attached. These are created once, outside the loop,
// and advanced with move_tile_window(); the Q and OGrad load windows are instead
// re-created at the top of every iteration.
auto qt_dram_window =
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
qt_dram_block_window.get_window_lengths(),
qt_dram_block_window.get_window_origin(),
Policy::template MakeQTDramTileDistribution<Problem>());
auto dot_dram_window =
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
dot_dram_block_window.get_window_lengths(),
dot_dram_block_window.get_window_origin(),
Policy::template MakeOGradTDramTileDistribution<Problem>());
auto lse_dram_window = make_tile_window(
lse_dram_block_window.get_bottom_tensor_view(),
lse_dram_block_window.get_window_lengths(),
lse_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window = make_tile_window(
d_dram_block_window.get_bottom_tensor_view(),
d_dram_block_window.get_window_lengths(),
d_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
bias_dram_block_window.get_window_lengths(),
bias_dram_block_window.get_window_origin(),
Policy::template MakeBiasTileDistribution<Problem>());
auto biast_lds_window =
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
biast_lds_shuffle_window.get_window_lengths(),
biast_lds_shuffle_window.get_window_origin(),
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start);
index_t i_total_loops = 0;
// Number of K-dimension slices each GEMM is split into.
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kM0 / kK1;
constexpr index_t k2_loops = kVHeaddim / kK2;
constexpr index_t k3_loops = kM0 / kK3;
constexpr index_t k4_loops = kN0 / kK4;
// NOTE(review): this is a do-while, so the body runs once even if num_total_loop <= 0;
// the early exit above only covers that case when FmhaMask::IsMasking is true.
do
{
auto q_dram_window = make_tile_window(
q_dram_block_window.get_bottom_tensor_view(),
q_dram_block_window.get_window_lengths(),
q_dram_block_window.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
// load
auto do_dram_window = make_tile_window(
do_dram_block_window.get_bottom_tensor_view(),
do_dram_block_window.get_window_lengths(),
do_dram_block_window.get_window_origin(),
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
// Q is streamed through LDS in kK0-wide slices: while gemm_0 consumes slice i from
// LDS, slice i + 1 is already held in registers (q_block_tile).
auto st_acc = SPTBlockTileType{};
auto q_block_tile = load_tile(q_dram_window);
{
move_tile_window(q_dram_window, {0, kK0});
clear_tile(st_acc); // Initialize S^T
store_tile(q_lds_window, q_block_tile); // LDS write 0
q_block_tile = load_tile(q_dram_window); // global read 1
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 2)
{
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
block_sync_lds();
move_tile_window(q_dram_window, {0, kK0});
store_tile(q_lds_window,
q_block_tile); // LDS write i + 1
q_block_tile = load_tile(q_dram_window); // global read i + 2
});
}
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
{ // tail
// The last two K slices: one already in LDS, one still in q_block_tile.
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kN0, (k0_loops - 1) * kK0>{}));
block_sync_lds();
store_tile(q_lds_window, q_block_tile);
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kN0, k0_loops * kK0>{}));
}
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
// Shuffle the bias tile in registers, round-trip it through LDS, and read it back
// with the distribution derived from gemm_0 so it can be added to st_acc.
block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile);
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
block_sync_lds();
auto biast_tile = load_tile(biast_lds_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = raw_scale * x + type_convert<AccDataType>(y);
#else
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
},
st_acc,
biast_tile);
move_tile_window(bias_dram_window, {kM0, 0});
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
// Scale each element, then let position_encoding adjust it from its (row, col).
const auto q_origin = q_dram_block_window.get_window_origin();
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc(i_j_idx) *= raw_scale;
#else
st_acc(i_j_idx) *= scale;
#endif
position_encoding.update(st_acc(i_j_idx), row, col);
});
});
}
else
{
// No bias: with fast exp2 enabled the scale is applied later, inside the exp2()
// call that builds pt, so nothing is done here in that configuration.
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
// Only tiles the mask reports as edge tiles need the per-element bound check.
const auto q_origin = q_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
const auto lse = load_tile(lse_dram_window);
// With elementwise bias or masking, an LSE of -infinity is replaced by 0 before it is
// subtracted below; otherwise the raw value is used unchanged.
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_lse == -numeric<LSEDataType>::infinity()
? type_convert<LSEDataType>(0.f)
: raw_lse;
}
else
{
return raw_lse;
}
};
// Build pt from st_acc and the per-row LSE: exp(s - lse), or the exp2 equivalent.
auto pt = SPTBlockTileType{};
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
// st_acc was already scaled in the bias branches above.
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
}
else
{
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
}
#else
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
});
});
// Push the prefetched OGrad^T slice into LDS (shuffled in registers first).
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
block_sync_lds();
{
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
}
move_tile_window(dot_dram_window, {0, kK1});
if constexpr(kHasDropout)
{
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
}
// STAGE 3, P^T@OGrad^T Gemm1
// With dropout, non-positive entries of pt are fed to the GEMM as 0. Presumably
// dropout.Run() marks dropped elements by making them negative — confirm in
// BlockDropout.
const auto pt_gemm = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
}
else
{
return cast_tile<GemmDataType>(pt);
}
}();
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(pt_gemm,
sequence<i_k1 * kK1, 0>{},
sequence<(i_k1 + 1) * kK1, kN0>{}),
dot_lds_window);
block_sync_lds();
shuffle_tile(dot_shuffle_tmp, dot);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
move_tile_window(dot_dram_window, {0, kK1});
});
}
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
// tail
{
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
dot_lds_window);
block_sync_lds();
}
// STAGE 4, OGrad@V Gemm2
// Same LDS streaming scheme as STAGE 1, with OGrad in kK2-wide slices and V sliced
// out of the persistent register tile.
auto dpt_acc = SPGradTBlockTileType{};
{
move_tile_window(do_dram_window, {0, kK2});
clear_tile(dpt_acc); // Initialize PGrad^T
store_tile(do_lds_window, do_block_tile); // LDS write 0
do_block_tile = load_tile(do_dram_window); // global read 1
}
if constexpr(k2_loops > 2)
{
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
block_sync_lds();
move_tile_window(do_dram_window, {0, kK2});
store_tile(do_lds_window,
do_block_tile); // LDS write i + 1
do_block_tile = load_tile(do_dram_window); // global read i + 2
});
}
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
{ // tail
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 2) * kK2>{},
sequence<kN0, (k2_loops - 1) * kK2>{}));
block_sync_lds();
store_tile(do_lds_window, do_block_tile);
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 1) * kK2>{},
sequence<kN0, k2_loops * kK2>{}));
}
// STAGE 5, P^T(PGrad^T - D)
const auto d = load_tile(d_dram_window);
auto dst = SPGradTBlockTileType{};
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// The sign of pt doubles as the "kept by dropout" flag. For a negative pt the
// factor is +d, so pt * d gives the same value as |pt| * (0 - d).
bool undrop_flag = pt[i_j_idx] >= 0;
dst(i_j_idx) =
pt[i_j_idx] *
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
});
});
if constexpr(kHasBiasGrad)
{
// dBias is dst (scaled by rp_undrop under dropout), round-tripped through the
// shared bias LDS view, shuffled to the DRAM bias distribution and stored.
const auto dbiast = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
dst);
}
else
{
return cast_tile<BiasGradDataType>(dst);
}
}();
store_tile(biast_lds_shuffle_window, dbiast);
block_sync_lds();
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_block_window, {kM0, 0});
}
// STAGE 6, SGrad^T@Q^T Gemm3
// Same prefetch / shuffle / LDS-store scheme as STAGE 3, with Q^T in kK3-wide slices.
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
block_sync_lds();
{
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
}
move_tile_window(qt_dram_window, {0, kK3});
const auto dst_gemm = cast_tile<GemmDataType>(dst);
if constexpr(k3_loops > 1)
{
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
const auto qt = load_tile(qt_dram_window); // load next Q^T
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(dst_gemm,
sequence<i_k3 * kK3, 0>{},
sequence<(i_k3 + 1) * kK3, kN0>{}),
qt_lds_window);
block_sync_lds();
shuffle_tile(qt_shuffle_tmp, qt);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
move_tile_window(qt_dram_window, {0, kK3});
});
}
// tail
{
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
qt_lds_window);
block_sync_lds();
}
// STAGE 7, SGrad@K^T Gemm4
// dS goes to LDS so gemm_4 can read it from there; K^T comes from the persistent K
// copy viewed through the KT descriptor.
store_tile(ds_lds_window, dst_gemm);
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc); // Initialize QGrad
block_sync_lds();
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
gemm_4(dq_acc,
get_slice_tile(ds_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kM0, (i_k4 + 1) * kK4>{}),
get_slice_tile(kt_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
});
// NOTE(review): no block_sync_lds() is visible between gemm_4's reads of
// ds_lds_window and the next iteration's store_tile(q_lds_window, ...), which
// starts at the same LDS offset. Confirm whether a barrier is needed here.
// QGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
const auto dq = cast_tile<QGradDataType>(dq_acc);
// dQ goes out through update_tile() rather than store_tile(); presumably this
// accumulates into memory because several K tiles contribute to one dQ tile — confirm.
update_tile(dq_dram_block_window, dq);
// move tile windows
// (bias, dbias, QT and OGradT windows were already advanced inside their stages.)
move_tile_window(q_dram_block_window, {kM0, 0});
move_tile_window(dq_dram_block_window, {kM0, 0});
move_tile_window(do_dram_block_window, {kM0, 0});
move_tile_window(lse_dram_window, {kM0});
move_tile_window(d_dram_window, {kM0});
} while(++i_total_loops < num_total_loop);
// KGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
// VGrad Scale
// dv is only rescaled under dropout; without dropout it is returned as accumulated.
if constexpr(kHasDropout)
{
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}
return ck_tile::make_tuple(dk_acc, dv_acc);
}
};
} // namespace ck_tile

View File

@@ -1,20 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// Default policy for the backward pipeline that keeps V in registers and K in LDS.
// It is BlockFmhaBwdPipelineDefaultPolicy with only the K and V "load once" flags set to
// true; every other operand (Q, Q^T, K^T, OGrad, OGrad^T) has its flag set to false.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile

View File

@@ -1,692 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
{
// Element types, taken from Problem with cv/ref qualifiers stripped.
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
// Launch configuration forwarded from Problem.
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// Block tile extents and per-GEMM K-slice widths forwarded from BlockFmhaShape.
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
// "Load once" flags fixed for this pipeline variant: Q, K, V and OGrad are true, the
// transposed operands (Q^T, K^T, OGrad^T) are false.
static constexpr bool kQLoadOnce = true;
static constexpr bool kQTLoadOnce = false;
static constexpr bool kKLoadOnce = true;
static constexpr bool kKTLoadOnce = false;
static constexpr bool kVLoadOnce = true;
static constexpr bool kOGradLoadOnce = true;
static constexpr bool kOGradTLoadOnce = false;
// Feature switches forwarded from Problem.
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
// Vector length along the last dimension, used when creating a tensor view (and, together
// with the tensor distribution, when deciding the buffer_load vector length). The tensor
// distribution should be able to override this.
// Each constant collapses to a small fixed value when the dimension named in its
// condition is padded, and otherwise takes the value reported by the policy.
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
// NOTE(review): the padded fallback here is 2, whereas every sibling constant falls back
// to 1. The reason is not visible in this file section — confirm it is intentional.
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
// Bias alignment is keyed on seqlen_k padding and uses the policy's "transposed" value.
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
// Short identifier string for this pipeline variant.
static constexpr const char* name = "qs_ks_vr_dos";
// Returns the shared-memory size this pipeline needs, exactly as reported by the policy
// for this Problem.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    const ck_tile::index_t smem_size = Policy::template GetSmemSize<Problem>();
    return smem_size;
}
template <typename QDramBlockWindowTmp,
typename QTDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename KTDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename OGradTDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp,
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
FmhaMask mask,
PositionEncoding position_encoding,
float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale,
#endif
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
BlockDropout& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
// QT tile in LDS
auto qt_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT<Problem>());
auto qt_lds_window =
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kM0>{}), {0, 0});
// K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
// KT tile in LDS
auto kt_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
auto kt_lds_window =
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
// OGrad tile in LDS
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeQ<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
// OGradT tile in LDS
auto dot_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT<Problem>());
auto dot_lds_window =
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kM0>{}), {0, 0});
// SGrad tile in LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
auto biast_lds_shuffle_window =
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto dbiast_lds_shuffle_window =
make_tile_window(biast_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
auto v = load_tile(v_dram_window); // persistent V register tile
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
clear_tile(dv_acc);
clear_tile(dk_acc);
auto k_dram_window = make_tile_window(
k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier(0);
const auto k_origin = k_dram_window.get_window_origin();
const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
{
if(num_total_loop <= 0)
{
// Note: dk_acc and dv_acc are both still cleared at this point, so return them as-is.
// Note: V was loaded above without a fence; it is unused on this path, so ignore it.
return ck_tile::make_tuple(dk_acc, dv_acc);
}
}
auto k_block_tile = load_tile(k_dram_window);
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
auto q_dram_block_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto do_dram_block_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto dq_dram_block_window =
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto lse_dram_block_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
auto d_dram_block_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_block_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
auto dbias_dram_block_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto lse_dram_window = make_tile_window(
lse_dram_block_window.get_bottom_tensor_view(),
lse_dram_block_window.get_window_lengths(),
lse_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window = make_tile_window(
d_dram_block_window.get_bottom_tensor_view(),
d_dram_block_window.get_window_lengths(),
d_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
bias_dram_block_window.get_window_lengths(),
bias_dram_block_window.get_window_origin(),
Policy::template MakeBiasTileDistribution<Problem>());
auto biast_lds_window =
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
biast_lds_shuffle_window.get_window_lengths(),
biast_lds_shuffle_window.get_window_origin(),
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kM0 / kK1;
constexpr index_t k2_loops = kVHeaddim / kK2;
constexpr index_t k3_loops = kM0 / kK3;
constexpr index_t k4_loops = kN0 / kK4;
do
{
auto q_dram_window = make_tile_window(
q_dram_block_window.get_bottom_tensor_view(),
q_dram_block_window.get_window_lengths(),
q_dram_block_window.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
// load
auto do_dram_window = make_tile_window(
do_dram_block_window.get_bottom_tensor_view(),
do_dram_block_window.get_window_lengths(),
do_dram_block_window.get_window_origin(),
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
auto st_acc = SPTBlockTileType{};
auto q_block_tile = load_tile(q_dram_window);
clear_tile(st_acc); // Initialize S^T
store_tile(q_lds_window, q_block_tile); // LDS write
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(st_acc,
get_slice_tile(q_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
block_sync_lds();
});
}
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
{ // tail
block_sync_lds();
gemm_0(st_acc,
get_slice_tile(q_lds_window,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kN0, k0_loops * kK0>{}));
block_sync_lds();
}
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile);
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
block_sync_lds();
auto biast_tile = load_tile(biast_lds_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = raw_scale * x + type_convert<AccDataType>(y);
#else
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
},
st_acc,
biast_tile);
move_tile_window(bias_dram_window, {kM0, 0});
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto q_origin = q_dram_block_window.get_window_origin();
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc(i_j_idx) *= raw_scale;
#else
st_acc(i_j_idx) *= scale;
#endif
position_encoding.update(st_acc(i_j_idx), row, col);
});
});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto q_origin = q_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
const auto lse = load_tile(lse_dram_window);
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_lse == -numeric<LSEDataType>::infinity()
? type_convert<LSEDataType>(0.f)
: raw_lse;
}
else
{
return raw_lse;
}
};
auto pt = SPTBlockTileType{};
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
}
else
{
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
}
#else
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
});
});
if constexpr(kHasDropout)
{
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
}
// STAGE 3, P^T@OGrad^T Gemm1
block_sync_lds();
store_tile(do_lds_window, do_block_tile); // store the prefetch
const auto pt_gemm = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
}
else
{
return cast_tile<GemmDataType>(pt);
}
}();
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(
pt_gemm, sequence<i_k1 * kK1, 0>{}, sequence<(i_k1 + 1) * kK1, kN0>{}),
get_slice_tile(dot_lds_window,
sequence<0, i_k1 * kK1>{},
sequence<kVHeaddim, (i_k1 + 1) * kK1>{}));
block_sync_lds();
});
// STAGE 4, OGrad@V Gemm2
auto dpt_acc = SPGradTBlockTileType{};
clear_tile(dpt_acc); // Initialize PGrad^T
static_for<0, k2_loops, 1>{}([&](auto i_k2) {
block_sync_lds();
gemm_2(dpt_acc,
get_slice_tile(do_lds_window,
sequence<0, i_k2 * kK2>{},
sequence<kM0, (i_k2 + 1) * kK2>{}),
get_slice_tile(
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
block_sync_lds();
});
// STAGE 5, P^T(PGrad^T - D)
const auto d = load_tile(d_dram_window);
auto dst = SPGradTBlockTileType{};
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = pt[i_j_idx] >= 0;
dst(i_j_idx) =
pt[i_j_idx] *
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
});
});
if constexpr(kHasBiasGrad)
{
const auto dbiast = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
dst);
}
else
{
return cast_tile<BiasGradDataType>(dst);
}
}();
store_tile(biast_lds_shuffle_window, dbiast);
block_sync_lds();
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_block_window, {kM0, 0});
}
// STAGE 6, SGrad^T@Q^T Gemm3
block_sync_lds();
const auto dst_gemm = cast_tile<GemmDataType>(dst);
static_for<0, k3_loops, 1>{}([&](auto i_k3) {
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(
dst_gemm, sequence<i_k3 * kK3, 0>{}, sequence<(i_k3 + 1) * kK3, kN0>{}),
get_slice_tile(qt_lds_window,
sequence<0, i_k3 * kK3>{},
sequence<kQKHeaddim, (i_k3 + 1) * kK3>{}));
block_sync_lds();
});
// STAGE 7, SGrad@K^T Gemm4
store_tile(ds_lds_window, dst_gemm);
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc); // Initialize QGrad
block_sync_lds();
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
gemm_4(dq_acc,
get_slice_tile(ds_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kM0, (i_k4 + 1) * kK4>{}),
get_slice_tile(kt_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
});
// QGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
const auto dq = cast_tile<QGradDataType>(dq_acc);
update_tile(dq_dram_block_window, dq);
// move tile windows
move_tile_window(q_dram_block_window, {kM0, 0});
move_tile_window(dq_dram_block_window, {kM0, 0});
move_tile_window(do_dram_block_window, {kM0, 0});
move_tile_window(lse_dram_window, {kM0});
move_tile_window(d_dram_window, {kM0});
} while(++i_total_loops < num_total_loop);
// KGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
// VGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}
return ck_tile::make_tuple(dk_acc, dv_acc);
}
};
} // namespace ck_tile

View File

@@ -1,20 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// Default policy for the QSKSVROGradS backward pipeline: V is kept in registers, while
// Q, K and OGrad (dO) are kept in LDS.
// It instantiates BlockFmhaBwdPipelineDefaultPolicy with the "load once" flags enabled for
// Q, K, V and OGrad, and disabled for the transposed operands (QT, KT, OGradT).
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ true,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile

View File

@@ -8,9 +8,8 @@ namespace ck_tile {
// Tag that names an FMHA backward pipeline variant.
// This class is used for codegen pattern matching, so the enumerator spellings are part of
// the contract with the code generator.
// NOTE(review): two enumerators below are explicitly assigned 0 (KSKTSVR and KRKTRVR_IGLP),
// so the implicit values that follow them also collide. This looks like the removed and the
// added enumerator lists of a diff hunk shown together - confirm which set is current.
enum class BlockFmhaBwdPipelineEnum
{
KSKTSVR = 0,
QSKSVROGradS,
KSVR,
KRKTRVR_IGLP = 0,
KRKTRVR,
};
} // namespace ck_tile

View File

@@ -24,7 +24,9 @@ template <typename QDataType_,
typename BiasGradDataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
bool kIsDeterministic_,
typename FmhaMask_,
typename FmhaDropout_,
typename Traits_>
struct BlockFmhaBwdPipelineProblem
{
@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using FmhaDropout = remove_cvref_t<FmhaDropout_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
@@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
// Compile-time problem description for the FMHA backward "convert QGrad" pipeline.
// NOTE(review): presumably the stage that turns the AccDataType dQ accumulator into
// QGradDataType - confirm against the pipeline that consumes this problem.
//
// Template parameters:
//   AccDataType_ / QGradDataType_ : element types, stored with cv/ref qualifiers removed.
//   kBlockSize_                   : threads per block; must be positive and a multiple of
//                                   get_warp_size() (enforced by the static_assert below).
//   kM0_, kN0_, kQKHeaddim_       : tile sizes, re-exported unchanged as kM0/kN0/kQKHeaddim.
//   kIsGroupMode_, kIsDeterministic_ : mode flags, re-exported unchanged.
//   Traits_                       : must provide kPadSeqLenQ, kPadHeadDimQ and kBlockPerCu.
template <typename AccDataType_,
typename QGradDataType_,
index_t kBlockSize_,
index_t kM0_,
index_t kN0_,
index_t kQKHeaddim_,
bool kIsGroupMode_,
bool kIsDeterministic_,
typename Traits_>
struct BlockFmhaBwdConvertQGradPipelineProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using QGradDataType = remove_cvref_t<QGradDataType_>;
using Traits = remove_cvref_t<Traits_>;
// Reject a zero/negative block size or one that does not consist of whole warps.
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kQKHeaddim = kQKHeaddim_;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile

View File

@@ -86,4 +86,14 @@ struct TileFmhaBwdOGradDotOTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
// Compile-time traits bundle for the FMHA backward "convert QGrad" step.
// It only re-exports its template arguments as static constexpr members so that a
// problem type can read them through a single Traits parameter.
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimQ_ /* padding for hdim_q */,
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
struct TileFmhaBwdConvertQGradTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile

View File

@@ -5,6 +5,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"

View File

@@ -0,0 +1,202 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
//
// Block-level GEMM in which all three operands are block distributed tensors.
// The block tile (BlockGemmShape::kM x kN x kK) is decomposed into warp-level GEMMs of the
// type chosen by Policy::GetWarpGemmMWarpNWarp<Problem>(), which also supplies the number of
// warps along M (MWarp) and along N (NWarp).
//
// Problem must provide ADataType, BDataType, CDataType, BlockGemmShape and kBlockSize.
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1
{
    using Problem        = remove_cvref_t<Problem_>;
    using Policy         = remove_cvref_t<Policy_>;
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // C += A * B
    //
    // Accumulates the product of a_block_tensor and b_block_tensor into c_block_tensor.
    // The element types and the tile distributions of all three tensors are verified at
    // compile time against what this block GEMM expects; a mismatch is a static_assert failure.
    template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockTensor& a_block_tensor,
                                   const BBlockTensor& b_block_tensor) const
    {
        static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
                          std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
                          std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
                      "wrong!");

        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;
        constexpr index_t KPerBlock = BlockGemmShape::kK;

        // config = (warp gemm instance, MWarp, NWarp)
        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        // Number of warp-GEMM invocations each warp performs along every dimension.
        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        // M->N Warp
        // Outer (block-level) distribution encodings. A carries an extra NWarp-sized leading
        // dimension and B an extra MWarp-sized one (presumably replication across the warps
        // of the other dimension - confirm against tile_distribution_encoding).
        constexpr auto a_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto b_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<MWarp>,
                                       tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
                                       tuple<sequence<0, 1>>,
                                       tuple<sequence<0, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        // Embed the warp-level encodings of the selected warp GEMM into the outer encodings.
        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});

        constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

        // check ABC-block-distribution
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
                           remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "A distribution is wrong!");

        static_assert(
            std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
                           remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "B distribution is wrong!");

        static_assert(
            std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
                           remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "C distribution is wrong!");

        using AWarpDstr = typename WG::AWarpDstr;
        using BWarpDstr = typename WG::BWarpDstr;
        using CWarpDstr = typename WG::CWarpDstr;

        using AWarpTensor = typename WG::AWarpTensor;
        using BWarpTensor = typename WG::BWarpTensor;
        using CWarpTensor = typename WG::CWarpTensor;

        // Per-warp Y lengths and all-zero Y origins, used to slice one warp tile out of the
        // block tensors' thread buffers.
        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto b_warp_y_lengths =
            to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
        constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        // hot loop:
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // read A warp tensor from A Block window
                AWarpTensor a_warp_tensor;

                a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // read B warp tensor from B block tensor
                    BWarpTensor b_warp_tensor;

                    b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM
                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }

    // Creates a C block tensor (element type CDataType) whose distribution is the one the
    // accumulating operator() checks for. The contents are not explicitly cleared here;
    // callers that accumulate into the tile are expected to clear it first.
    CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
    {
        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;

        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        // Only the M and N iteration counts are needed to describe C; K does not appear.
        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);

        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor         = make_static_distributed_tensor<CDataType>(c_block_dstr);
        return c_block_tensor;
    }

    // C = A * B
    //
    // Returns a freshly created C block tensor holding the product of the two operands.
    template <typename ABlockTensor, typename BBlockTensor>
    CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
                                   const BBlockTensor& b_block_tensor) const
    {
        auto c_block_tensor = MakeCBlockTile();
        // The overload called below accumulates (C += A * B), so C has to start from zero.
        // Clear it explicitly instead of relying on how make_static_distributed_tensor
        // happens to initialise its thread buffer.
        clear_tile(c_block_tensor);
        operator()(c_block_tensor, a_block_tensor, b_block_tensor);
        return c_block_tensor;
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Custom policy for a register/register/register block GEMM: the warp GEMM type and the
// warp layout are supplied directly as template arguments instead of being derived from
// the Problem's data types.
//   AType_/BType_/CType_ : operand element types, exposed as AType/BType/CType.
//   BlockWarps_          : compile-time sequence; elements 0, 1, 2 are read as the number
//                          of warps along M, N and K.
//   WarpGemm_            : warp-level GEMM type returned by GetWarpGemmMWarpNWarp().
template <typename AType_,
typename BType_,
typename CType_,
typename BlockWarps_,
typename WarpGemm_>
struct BlockGemmARegBRegCRegV1CustomPolicy
{
using AType = remove_cvref_t<AType_>;
using BType = remove_cvref_t<BType_>;
using CType = remove_cvref_t<CType_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
// Exposed for completeness; GetWarpGemmMWarpNWarp() below does not return it.
static constexpr index_t kKWarps = BlockWarps::at(number<2>{});
using WarpGemm = remove_cvref_t<WarpGemm_>;
// Returns the tuple (WarpGemm instance, kMWarps, kNWarps).
// The Problem parameter is not used in the body.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
return make_tuple(WarpGemm{}, kMWarps, kNWarps);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {
// Default policy for BlockGemmARegBRegCRegV1
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmARegBRegCRegV1DefaultPolicy
{
    // Selects the warp-level GEMM and the warp layout from the Problem's data types.
    //
    // Returns the tuple (warp gemm instance, MWarp, NWarp). Supported combinations:
    //   half_t x half_t -> float : WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution, 4, 1
    //   bf16_t x bf16_t -> float : WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution, 4, 1
    // Any other combination is rejected at compile time.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
                     std::is_same_v<typename Problem::BDataType, half_t> &&
                     std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
        }
        else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
                          std::is_same_v<typename Problem::BDataType, bf16_t> &&
                          std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
        }
        else
        {
            // Without this branch the return type would silently deduce to `void` and the
            // caller would fail later with an unrelated error. The condition depends on
            // Problem so that it is only evaluated when this branch is instantiated.
            static_assert(!std::is_same_v<Problem, Problem>,
                          "BlockGemmARegBRegCRegV1DefaultPolicy: unsupported A/B/C data type "
                          "combination; use a custom policy instead");
        }
    }
};
} // namespace ck_tile

View File

@@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
// KPerBlock == BlockGemmShape::kK,
// "wrong!");
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

View File

@@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
// KPerBlock == BlockGemmShape::kK,
// "wrong!");
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

View File

@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 =
// fp16 warp GEMM, M16 N16 K32: the M16N16K16 MFMA implementation wrapped in the IterateK
// attribute with a second template argument of 2.
using WarpGemmMfmaF16F16F32M16N16K32 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
// fp16 warp GEMMs using the IterateK_SwizzleA attribute on top of the M32N32K8 MFMA
// implementation: K8 passes 1 and K16 passes 2 as the second template argument
// (presumably the number of K iterations - confirm against the attribute definition).
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
@@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 =
using WarpGemmMfmaBf16Bf16F32M16N16K32 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
// bf16 SwizzleA warp GEMMs. Both aliases wrap the same bf16 M32N32K8 MFMA impl in
// WarpGemmAtrributeMfmaIterateK_SwizzleA and differ only in the second template argument
// (1 for the K8 alias, 2 for the K16 alias).
// NOTE(review): that argument is presumably the number of K iterations -- confirm.
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;

View File

@@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a>(a_vec)
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b>(b_vec)
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
});
}
@@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK
// c = a * b
auto c_vec = Impl{}(
reinterpret_cast<const buf_a>(a_vec).template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_b>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
// c += a * b
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a>(a_vec)
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b>(b_vec)
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
});

View File

@@ -15,7 +15,8 @@ template <typename AType,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
bool TransposeC>
bool TransposeC,
bool SwizzleA = false>
struct WarpGemmMfmaDispatcher;
// clang-format off
@@ -29,6 +30,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
// fp16 16x16x32: the TransposeC argument (false / true) selects the plain or the
// TransposedCDistribution warp GEMM type.
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
// fp16 32x32x8 and 32x32x16 with the trailing SwizzleA argument set to true (TransposeC = false);
// these map to the *SwizzleA warp GEMM aliases.
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
// bf16
// bf16 32x32x8: the TransposeC argument (false / true) selects the plain or the
// TransposedCDistribution warp GEMM type; the SwizzleA argument is omitted here.
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
@@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
// bf16 16x16x32: the TransposeC argument (false / true) selects the plain or the
// TransposedCDistribution warp GEMM type.
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
// bf16 32x32x8 and 32x32x16 with the trailing SwizzleA argument set to true (TransposeC = false);
// these map to the *SwizzleA warp GEMM aliases.
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
// fp8
// fp8 x fp8 32x32x16: TransposeC = false maps to the plain warp GEMM type,
// TransposeC = true to its _CTransposed counterpart.
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
@@ -58,8 +65,15 @@ template <typename AType,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
bool TransposeC>
using WarpGemmMfmaDispatcher = typename impl::
WarpGemmMfmaDispatcher<AType, BType, CType, MPerWave, NPerWave, KPerWave, TransposeC>::Type;
bool TransposeC,
bool SwizzleA = false>
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
BType,
CType,
MPerWave,
NPerWave,
KPerWave,
TransposeC,
SwizzleA>::Type;
} // namespace ck_tile