mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] FA bwd kernels optimization (#1397)
* tmp save
* fix batch deterministic bugs
* fix group deterministic bugs
* codegen update
* reorder files
* bias support
* hd256 bias support
* bwd smoke test update
* simplify convert dq
* fix hd256 dropout scratch
* do{}while() -> while(){}
* comments
* remove FmhaBwdTilePartitioner
* save clear_tile
* refactor dropout
* code cleanup
* code cleanup
* comments
* fix epilogue problem
* fix fwd dropout
* group convert_dq opt
* fix dq alignment
* Do not store storerandval in bwd for flash attention integration
* fix hd32 error and boost performance
* revert
* Remove duplicated WarpGemm definitions in the policy file
* dropout patch for mrepeat 16*16
* code sync up
* dq_acc stride
* dq_acc stride stuff
* codegen update
* fwd dropout revert
* fix hd128 scratches and boost performance
* receipt 3 for simplified smoke test
* more strides for fa integration
* fix hd64 scratches and boost performance
* non-iglp pipeline for headdim padding cases
* dpad same as dvpad for flash attention integration
* unpadded lse&d for group mode
* Support unpad layout for group lse
* Support unpad lse layout for splitkv
* Fix stride for splitkv kernel
* fix unpadded lse issue in fwd splitkv
* comment
* solve lds read&write conflicts
* rename
* bias rename
* tile index revert
---------
Co-authored-by: danyao12 <danyao12>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>
[ROCm/composable_kernel commit: 79a5d9c10c]
This commit is contained in:
@@ -6,7 +6,7 @@ execute_process(
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
|
||||
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3
|
||||
)
|
||||
|
||||
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
|
||||
@@ -23,7 +23,7 @@ add_custom_command(
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_BWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3
|
||||
)
|
||||
|
||||
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
@@ -55,11 +55,10 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
|
||||
# ... because they are auto-generated
|
||||
if(FMHA_FWD_FAST_EXP2)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
|
||||
|
||||
# Allow comparing floating points directly in order to check sentinel values
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
|
||||
|
||||
@@ -66,6 +66,22 @@ BIAS_CHECK_MAP = {
|
||||
"alibi" : "bias_enum::alibi"
|
||||
}
|
||||
|
||||
DROPOUT_MAP = {
|
||||
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
|
||||
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
|
||||
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
|
||||
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
|
||||
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
|
||||
}
|
||||
|
||||
DROPOUT_CHECK_MAP = {
|
||||
"no" : "t.has_dropout == false",
|
||||
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
|
||||
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
|
||||
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
|
||||
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
|
||||
}
|
||||
|
||||
MODE_MAP = {
|
||||
"batch" : "false",
|
||||
"group" : "true"
|
||||
|
||||
@@ -14,15 +14,13 @@ from codegen.cpp_symbol_map import *
|
||||
|
||||
|
||||
BWD_DQDKDV_PIPELINE_MAP = {
|
||||
"ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR",
|
||||
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
|
||||
"ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR",
|
||||
"kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP",
|
||||
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR",
|
||||
}
|
||||
|
||||
BWD_DQDKDV_PIPELINE_ENUM_MAP = {
|
||||
"ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR",
|
||||
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS",
|
||||
"ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR",
|
||||
"kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP",
|
||||
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR",
|
||||
}
|
||||
|
||||
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
@@ -34,39 +32,42 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
FMHA_BWD_DQ_DK_DV_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
|
||||
using fmha_block_tile_{F_idx} = ck_tile::
|
||||
sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
|
||||
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
|
||||
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
|
||||
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
|
||||
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
|
||||
using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
|
||||
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
|
||||
fmha_block_warps0_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps0_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps2_{F_idx},
|
||||
fmha_warp_tile_{F_idx}>;
|
||||
fmha_block_warps0_{F_idx},
|
||||
fmha_warp_tile0_{F_idx},
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile1_{F_idx},
|
||||
fmha_block_warps0_{F_idx},
|
||||
fmha_warp_tile0_{F_idx},
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile1_{F_idx},
|
||||
fmha_block_warps2_{F_idx},
|
||||
fmha_warp_tile0_{F_idx}>;
|
||||
|
||||
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
false,
|
||||
{F_dropout},
|
||||
false,
|
||||
{F_occupancy}>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{F_occupancy}>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
using fmha_dropout_{F_idx} = {F_dropout};
|
||||
|
||||
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
@@ -86,55 +87,72 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
|
||||
fmha_bwd_shape_{F_idx},
|
||||
{F_mode},
|
||||
{F_deterministic},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_dropout_{F_idx},
|
||||
fmha_bwd_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_bwd_pipeline_problem_{F_idx}>;
|
||||
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<fmha_bwd_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
|
||||
false, false>>;
|
||||
using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
|
||||
{F_skpad},
|
||||
{F_dpad}>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
|
||||
false, false>>;
|
||||
using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
|
||||
{F_skpad},
|
||||
{F_dvpad}>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>,
|
||||
fmha_bwd_pipeline_{F_idx},
|
||||
fmha_bwd_dk_epilogue_{F_idx},
|
||||
fmha_bwd_dv_epilogue_{F_idx}>;
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
|
||||
fmha_bwd_dk_epilogue_{F_idx},
|
||||
fmha_bwd_dv_epilogue_{F_idx}>;
|
||||
|
||||
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
|
||||
{F_dtype},
|
||||
{F_mode},
|
||||
{F_pipeline_enum},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_dropout_{F_idx},
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_deterministic}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template<>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
@@ -146,14 +164,15 @@ FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
|
||||
FMHA_BWD_API="""
|
||||
#include <iostream>
|
||||
|
||||
template<typename dot_do_o_trait_, typename dq_dk_dv_trait_>
|
||||
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
|
||||
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
|
||||
@@ -173,38 +192,36 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_>(s, a);
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FmhaBwdDQDKDVApiTrait:
|
||||
pipeline : str
|
||||
pipeline : str
|
||||
# sync with fmha_bwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along k seqlen
|
||||
bhdq : int # q head_dim
|
||||
bhdv : int # v head_dim
|
||||
mask : str
|
||||
bias : str
|
||||
dbias : str
|
||||
dropout : str
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along k seqlen
|
||||
bhdq : int # q head_dim
|
||||
bhdv : int # v head_dim
|
||||
mask : str
|
||||
bias : str
|
||||
dbias : str
|
||||
dropout : str
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
deterministic : str
|
||||
|
||||
def scheck(self, spad1 : str) -> str:
|
||||
if self.mode == 'group':
|
||||
@@ -212,9 +229,9 @@ class FmhaBwdDQDKDVApiTrait:
|
||||
elif self.spad == 't' and spad1 == 't':
|
||||
return f'a.seqlen_q % {self.bm0} != 0'
|
||||
elif self.spad == 'f' and spad1 == 't':
|
||||
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize
|
||||
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
|
||||
else: # self.skpad == 'f' and skpad1 == 'f'
|
||||
return f'a.seqlen_q % 256 == 0' # BlockSize
|
||||
return f'a.seqlen_q % 64 == 0'
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
@@ -256,16 +273,19 @@ class FmhaBwdApiPool:
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
|
||||
traits=self.dq_dk_dv_pool[dtype][hdim]
|
||||
hdim_int = int(hdim)
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
for spad1 in ["t", "f"]:
|
||||
if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")):
|
||||
if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
|
||||
continue
|
||||
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
|
||||
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
|
||||
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad])
|
||||
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_deterministic=BOOL_MAP[trait.deterministic])
|
||||
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
@@ -295,81 +315,89 @@ class FmhaBwdDQDKDVTileSize:
|
||||
F_bhdv : int # v head_dim
|
||||
F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2
|
||||
F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2
|
||||
F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2
|
||||
F_rk0 : int # number of warps along headdim_qk/v (not used) in gemm0/gemm2
|
||||
F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3
|
||||
F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3
|
||||
F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3
|
||||
F_rm2 : int # number of warps along k seqlen (block warps) in gemm4
|
||||
F_rn2 : int # number of warps along q seqlen (block warps) in gemm4
|
||||
F_rk2 : int # number of warps along gemm-k (not used) in gemm4
|
||||
F_wm : int # warp size along m (warp size)
|
||||
F_wn : int # warp size along n
|
||||
F_wk : int # warp size along k
|
||||
F_rn1 : int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3
|
||||
F_rk1 : int # number of warps along q seqlen (not used) in gemm1/gemm3
|
||||
F_rm2 : int # number of warps along q seqlen (block warps) in gemm4
|
||||
F_rn2 : int # number of warps along headdim_qk (block warps) in gemm4
|
||||
F_rk2 : int # number of warps along k seqlen (not used) in gemm4
|
||||
F_wm0 : int # warp size along m in gemm0/gemm2/gemm4
|
||||
F_wn0 : int # warp size along n in gemm0/gemm2/gemm4
|
||||
F_wk0 : int # warp size along k in gemm0/gemm2/gemm4
|
||||
F_wm1 : int # warp size along m in gemm1/gemm3
|
||||
F_wn1 : int # warp size along n in gemm1/gemm3
|
||||
F_wk1 : int # warp size along k in gemm1/gemm3
|
||||
F_occupancy : int # occupancy
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
|
||||
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}"
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
|
||||
|
||||
@dataclass
|
||||
class FmhaBwdDQDKDVKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_tile : FmhaBwdDQDKDVTileSize
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_bias : str #
|
||||
F_dbias : str #
|
||||
F_dropout : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_pipeline : str
|
||||
mask_impl : str
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_tile : FmhaBwdDQDKDVTileSize
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_bias : str #
|
||||
F_dbias : str #
|
||||
F_dropout : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_deterministic : str #
|
||||
F_pipeline : str #
|
||||
mask_impl : str #
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_BWD_KERNEL_HEADER + \
|
||||
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk2 = self.F_tile.F_bk2,
|
||||
F_bk3 = self.F_tile.F_bk3,
|
||||
F_bk4 = self.F_tile.F_bk4,
|
||||
F_bhdq = self.F_tile.F_bhdq,
|
||||
F_bhdv = self.F_tile.F_bhdv,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_rm2 = self.F_tile.F_rm2,
|
||||
F_rn2 = self.F_tile.F_rn2,
|
||||
F_rk2 = self.F_tile.F_rk2,
|
||||
F_wm = self.F_tile.F_wm,
|
||||
F_wn = self.F_tile.F_wn,
|
||||
F_wk = self.F_tile.F_wk,
|
||||
F_spad = BOOL_MAP[self.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_dvpad],
|
||||
F_bias = BIAS_MAP[self.F_bias],
|
||||
F_dbias = BOOL_MAP[self.F_dbias],
|
||||
F_dropout = BOOL_MAP[self.F_dropout],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk2 = self.F_tile.F_bk2,
|
||||
F_bk3 = self.F_tile.F_bk3,
|
||||
F_bk4 = self.F_tile.F_bk4,
|
||||
F_bhdq = self.F_tile.F_bhdq,
|
||||
F_bhdv = self.F_tile.F_bhdv,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_rm2 = self.F_tile.F_rm2,
|
||||
F_rn2 = self.F_tile.F_rn2,
|
||||
F_rk2 = self.F_tile.F_rk2,
|
||||
F_wm0 = self.F_tile.F_wm0,
|
||||
F_wn0 = self.F_tile.F_wn0,
|
||||
F_wk0 = self.F_tile.F_wk0,
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_spad = BOOL_MAP[self.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_dvpad],
|
||||
F_bias = BIAS_MAP[self.F_bias],
|
||||
F_dbias = BOOL_MAP[self.F_dbias],
|
||||
F_dropout = DROPOUT_MAP[self.F_dropout],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_deterministic = BOOL_MAP[self.F_deterministic],
|
||||
F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
|
||||
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
|
||||
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -382,7 +410,7 @@ class FmhaBwdDQDKDVKernel:
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
|
||||
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
if self.F_dbias == 't' : n += '_dbias'
|
||||
@@ -390,7 +418,8 @@ class FmhaBwdDQDKDVKernel:
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
|
||||
if self.F_deterministic == 't' : n += '_deterministic'
|
||||
return n
|
||||
|
||||
@property
|
||||
@@ -413,19 +442,23 @@ class FmhaBwdDQDKDVKernel:
|
||||
spad=self.F_spad,
|
||||
skpad=self.F_skpad,
|
||||
dpad=self.F_dpad,
|
||||
dvpad=self.F_dvpad)
|
||||
dvpad=self.F_dvpad,
|
||||
deterministic=self.F_deterministic
|
||||
)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size & pipeline.
|
||||
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1),
|
||||
"qs_ks_vr_dos"],
|
||||
'64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
|
||||
"qs_ks_vr_dos"],
|
||||
'128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
|
||||
"ks_vr"]
|
||||
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"]
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -440,7 +473,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
tile = d[hdim_str][0]
|
||||
ppl = d[hdim_str][1]
|
||||
hdim = int(hdim_str)
|
||||
@@ -448,16 +481,29 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
continue
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
if ("wg32" in dropout):
|
||||
continue
|
||||
if (dpad == "t" or dvpad == "t"):
|
||||
ppl = d[hdim_str][2]
|
||||
k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
|
||||
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
|
||||
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
|
||||
F_pipeline=ppl, mask_impl=mask_impl)
|
||||
F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic)
|
||||
if kernel_filter != None:
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
if receipt == 3:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_dq_dk_dv_traits(k.api_trait())
|
||||
@@ -468,53 +514,54 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad},
|
||||
{F_dvpad},
|
||||
{F_occupancy}>;
|
||||
using fmha_bwd_dot_do_o_trait_{F_idx} =
|
||||
ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
|
||||
/* BlockSize = */ 256,
|
||||
/* BlockSize = */ 64,
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
fmha_bwd_dot_do_o_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO<
|
||||
fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
|
||||
using fmha_bwd_dot_do_o_{F_idx} =
|
||||
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdOGradDotOTilePartitioner</* BlockSize = */ 256>,
|
||||
fmha_bwd_dot_do_o_{F_idx}>;
|
||||
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;
|
||||
|
||||
using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
|
||||
using dot_do_o_trait_{F_idx} =
|
||||
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
template <>
|
||||
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
@@ -584,12 +631,150 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
|
||||
|
||||
return gen
|
||||
|
||||
FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_bwd_convert_dq_trait_{F_idx} =
|
||||
ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;
|
||||
|
||||
using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
|
||||
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
|
||||
/* BlockSize = */ 256,
|
||||
{F_bm0},
|
||||
{F_bn0},
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
{F_deterministic},
|
||||
fmha_bwd_convert_dq_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_convert_dq_{F_idx} =
|
||||
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_convert_dq_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;
|
||||
|
||||
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
|
||||
{F_dtype},
|
||||
{F_mode},
|
||||
{F_spad},
|
||||
{F_dpad},
|
||||
{F_deterministic}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
|
||||
return k_::GetName();
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FmhaBwdConvertQGradKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_spad : str # true/false
|
||||
F_dpad : str #
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_occupancy : int #
|
||||
F_deterministic : str #
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_BWD_KERNEL_HEADER + \
|
||||
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_bm0,
|
||||
F_bn0 = self.F_bn0,
|
||||
F_spad = BOOL_MAP[self.F_spad],
|
||||
F_dpad = BOOL_MAP[self.F_dpad],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_occupancy = self.F_occupancy,
|
||||
F_deterministic = BOOL_MAP[self.F_deterministic])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_deterministic == 't' : n += f'_deterministic'
|
||||
return n
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
# support this in future
|
||||
def get_occupancy(dtype, hdim):
|
||||
return 2
|
||||
|
||||
gen = list()
|
||||
|
||||
for dtype in DTYPE_MAP.keys():
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
hdim = int(hdim_str)
|
||||
tile = d[hdim_str][0]
|
||||
if (mode == "group" and spad == "f"):
|
||||
continue
|
||||
k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
|
||||
F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
|
||||
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
@@ -597,6 +782,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
|
||||
kernels = get_bwd_dot_do_o_blobs()
|
||||
for kernel in kernels:
|
||||
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
|
||||
kernels = get_bwd_convert_dq_blobs()
|
||||
for kernel in kernels:
|
||||
write_single_bwd_convert_dq_kernel(kernel, output_dir)
|
||||
api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
|
||||
@@ -605,6 +793,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
|
||||
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
kernels = get_bwd_dot_do_o_blobs()
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
kernels = get_bwd_convert_dq_blobs()
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
|
||||
|
||||
@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("deterministic",
|
||||
"0",
|
||||
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
|
||||
"will not be used");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -128,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
|
||||
{
|
||||
std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
|
||||
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
|
||||
@@ -177,9 +176,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seed.reset();
|
||||
}
|
||||
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
bool deterministic = arg_parser.get_bool("deterministic");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
@@ -265,6 +265,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
|
||||
const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
|
||||
const ck_tile::index_t nsplits =
|
||||
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
@@ -284,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
|
||||
ck_tile::HostTensor<DDataType> d_host(
|
||||
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host(
|
||||
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
@@ -302,6 +305,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
use_dbias
|
||||
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<AccDataType> dq_acc_host(
|
||||
i_perm
|
||||
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
|
||||
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
@@ -362,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -387,8 +395,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
|
||||
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
|
||||
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask
|
||||
<< std::flush;
|
||||
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
|
||||
<< ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
|
||||
|
||||
std::size_t workspace_size =
|
||||
dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024);
|
||||
|
||||
if(deterministic == 1)
|
||||
{
|
||||
std::cout << "\nDeterministic mode ON: " << workspace_size
|
||||
<< " MByte memory workspace allocated" << std::endl;
|
||||
}
|
||||
|
||||
auto fmha_traits = fmha_bwd_traits{hdim_q,
|
||||
hdim_v,
|
||||
@@ -397,7 +414,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
mask.type,
|
||||
bias.type,
|
||||
use_dbias,
|
||||
p_drop > 0.0f};
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic};
|
||||
auto fmha_args = [&]() {
|
||||
assert(nhead % nhead_k == 0);
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
@@ -422,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_lsed = max_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_dbias =
|
||||
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
|
||||
// setup batch_stride_* arguments
|
||||
@@ -433,10 +452,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
|
||||
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t split_stride_dq_acc =
|
||||
(shape_batch * nhead * shape_seqlen_q * hdim_q);
|
||||
|
||||
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
@@ -452,6 +473,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
dk_buf.GetDeviceBuffer(),
|
||||
dv_buf.GetDeviceBuffer(),
|
||||
dbias_buf.GetDeviceBuffer(),
|
||||
dq_acc_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
@@ -473,6 +495,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
stride_o,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_q, // stride_dq_acc
|
||||
stride_q, // stride_dq
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
@@ -484,6 +508,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_q, // nhead_stride_dq_acc
|
||||
nhead_stride_q, // nhead_stride_dq
|
||||
nhead_stride_k, // nhead_stride_dk
|
||||
nhead_stride_v, // nhead_stride_dv
|
||||
nhead_stride_dbias,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
@@ -493,15 +521,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_q, // batch_stride_dq_acc
|
||||
batch_stride_q, // batch_stride_dq
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_drop,
|
||||
p_undrop,
|
||||
s_randval,
|
||||
{drop_seed, drop_offset}};
|
||||
}();
|
||||
|
||||
@@ -719,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
|
||||
else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
|
||||
|
||||
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); });
|
||||
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
|
||||
// clang-format on
|
||||
|
||||
q_host_refs.push_back(q_host_ref);
|
||||
@@ -738,6 +768,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dq_buf.SetZero();
|
||||
dbias_buf.SetZero();
|
||||
dq_acc_buf.SetZero();
|
||||
|
||||
ck_tile::stream_config stream_config_v{
|
||||
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
|
||||
|
||||
@@ -77,6 +77,7 @@ struct fmha_bwd_args
|
||||
void* dk_ptr;
|
||||
void* dv_ptr;
|
||||
void* dbias_ptr;
|
||||
void* dq_acc_ptr;
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
@@ -97,6 +98,8 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_do;
|
||||
ck_tile::index_t stride_dq_acc;
|
||||
ck_tile::index_t stride_dq;
|
||||
ck_tile::index_t stride_dk;
|
||||
ck_tile::index_t stride_dv;
|
||||
ck_tile::index_t stride_dbias;
|
||||
@@ -108,6 +111,10 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dq;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
ck_tile::index_t nhead_stride_dbias;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
@@ -117,15 +124,17 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
ck_tile::index_t batch_stride_dq_acc;
|
||||
ck_tile::index_t batch_stride_dq;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
ck_tile::index_t batch_stride_dbias;
|
||||
ck_tile::index_t split_stride_dq_acc;
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
float p_drop;
|
||||
float p_undrop;
|
||||
bool s_randval;
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset;
|
||||
};
|
||||
|
||||
@@ -145,10 +154,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dq_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
@@ -163,6 +172,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
@@ -173,13 +183,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_lsed,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
@@ -192,10 +204,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dq_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
@@ -209,6 +221,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
@@ -219,6 +232,9 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
@@ -227,14 +243,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
args.batch_stride_dq_acc,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
@@ -260,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_lsed,
|
||||
args.batch_stride_lsed);
|
||||
args.nhead_stride_lsed);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -286,19 +302,59 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaBwdConvertQGradKernel>
|
||||
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.batch_stride_dq,
|
||||
args.batch_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
typename FmhaDropout_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kHasDropout_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
bool kPadDv_,
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_dq_dk_dv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -306,13 +362,14 @@ struct fmha_bwd_dq_dk_dv_traits_
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
@@ -343,6 +400,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_dot_do_o_get_name_();
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
bool kPadS_,
|
||||
bool kPadD_,
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_convert_dq_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_convert_dq_get_name_();
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fmha_bwd_traits
|
||||
{
|
||||
@@ -354,6 +436,8 @@ struct fmha_bwd_traits
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_dbias;
|
||||
bool has_dropout;
|
||||
bool is_store_randval;
|
||||
bool is_deterministic;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -479,16 +479,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: std::array<ck_tile::index_t, 2>{1, 1});
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
||||
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
1 < num_splits
|
||||
? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
ck_tile::HostTensor<OaccDataType> o_acc_host(
|
||||
1 < num_splits
|
||||
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
// self define lse data layout as [batch, nhead, max_seqlen_q]
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
// group mode of lse data layout is [nhead, total_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
|
||||
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
@@ -669,8 +671,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
// setup batch_stride_* arguments
|
||||
@@ -679,12 +681,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
|
||||
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
|
||||
|
||||
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
|
||||
@@ -996,8 +998,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
lse_host_result.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
|
||||
lse_host_result.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = lse_host(b, idx[0], idx[1] + query_offset);
|
||||
});
|
||||
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
|
||||
@@ -185,7 +185,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_lse,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
@@ -284,7 +283,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc,
|
||||
@@ -376,9 +374,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_o_acc,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.batch_stride_lse,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc);
|
||||
}
|
||||
|
||||
@@ -11,18 +11,19 @@ COMMON_ARGS='-v=1'
|
||||
set -x
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 32 64 128 ; do
|
||||
for hdim in 32 64 128 256 ; do
|
||||
for mode in 0 1 ; do
|
||||
for bias in "n" "e" "a"; do
|
||||
for dbias in 0 1 ; do
|
||||
for p_drop in 0.0 0.2; do
|
||||
for bias in "n" "a" ; do
|
||||
for dbias in 0 ; do
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
for deterministic in 0 ; do
|
||||
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done
|
||||
done
|
||||
@@ -31,4 +32,5 @@ done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
set +x
|
||||
|
||||
@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
|
||||
};
|
||||
|
||||
// 2D XOR, NOTE: "xor" is a keyword
|
||||
template <typename LowLengths, typename RightShift>
|
||||
template <typename LowLengths>
|
||||
struct xor_t : public base_transform<2, 2>
|
||||
{
|
||||
static constexpr auto type_enum = coord_transform_enum::xor_t;
|
||||
@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
|
||||
using UpLengths = LowLengths;
|
||||
|
||||
UpLengths up_lengths_;
|
||||
RightShift right_shift_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
|
||||
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths,
|
||||
const RightShift& right_shift)
|
||||
: up_lengths_{low_lengths}, right_shift_{right_shift}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
|
||||
{
|
||||
@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
|
||||
|
||||
idx_low(number<0>{}) = idx_up[number<0>{}];
|
||||
|
||||
const auto idx_low_1_tmp =
|
||||
(idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
|
||||
|
||||
const auto idx_low_1 =
|
||||
(idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
|
||||
|
||||
idx_low(number<1>{}) = idx_low_1;
|
||||
idx_low(number<1>{}) =
|
||||
idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
ck_tile::is_known_at_compile_time<RightShift>::value;
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
// MUST be static function
|
||||
@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
|
||||
array<index_t, 2> up_vector_lengths = low_vector_lengths;
|
||||
array<index_t, 2> up_vector_strides = low_vector_strides;
|
||||
|
||||
if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
|
||||
{
|
||||
if(low_vector_lengths[1] != -1)
|
||||
{
|
||||
up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("right_shift_: ");
|
||||
print(right_shift_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
|
||||
return modulo<Modulus, UpLength>{modulus, up_length};
|
||||
}
|
||||
|
||||
template <typename LowLengths, typename RightShift>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
|
||||
const RightShift& right_shift)
|
||||
template <typename LowLengths>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
|
||||
{
|
||||
return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
|
||||
return xor_t<LowLengths>{low_lengths};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename OffsetLength>
|
||||
|
||||
@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
|
||||
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
|
||||
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// u32
|
||||
// using uint32_t = ...
|
||||
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
|
||||
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
|
||||
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
|
||||
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
|
||||
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
|
||||
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i16
|
||||
// using int16_t = ...
|
||||
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
return make_tuple(
|
||||
make_static_tile_distribution(
|
||||
tile_distribution_encoding<typename Encoding::RsLengths,
|
||||
decltype(sliced_h_lengths), // only need to change the
|
||||
// h_lengths type
|
||||
remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
|
||||
// change the
|
||||
// h_lengths type
|
||||
typename Encoding::Ps2RHssMajor,
|
||||
typename Encoding::Ps2RHssMinor,
|
||||
typename Encoding::Ys2RHsMajor,
|
||||
|
||||
@@ -53,6 +53,39 @@ class philox
|
||||
out_tmp[3] = tmp_ph.w;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
|
||||
const unsigned long long subsequence,
|
||||
const index_t start_idx) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32x4_t tmp;
|
||||
tmp[0] = tmp_ph.x;
|
||||
tmp[1] = tmp_ph.y;
|
||||
tmp[2] = tmp_ph.z;
|
||||
tmp[3] = tmp_ph.w;
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
out_tmp[0] = tmp[start_idx];
|
||||
out_tmp[1] = tmp[start_idx + 2];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
|
||||
const unsigned long long subsequence,
|
||||
const index_t start_idx) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32x4_t tmp;
|
||||
tmp[0] = tmp_ph.x;
|
||||
tmp[1] = tmp_ph.y;
|
||||
tmp[2] = tmp_ph.z;
|
||||
tmp[3] = tmp_ph.w;
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
out_tmp[0] = tmp[start_idx];
|
||||
}
|
||||
|
||||
private:
|
||||
struct ull2
|
||||
{
|
||||
|
||||
@@ -8,21 +8,16 @@
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
||||
|
||||
@@ -286,11 +286,226 @@ struct BlockDropout
|
||||
});
|
||||
}
|
||||
|
||||
ck_tile::philox ph;
|
||||
const float rp_undrop;
|
||||
const uint8_t p_undrop_in_uint8_t;
|
||||
const bool is_store_randval;
|
||||
};
|
||||
|
||||
template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
|
||||
struct BlockDropoutBwd;
|
||||
|
||||
template <bool IsWG32_, bool IsStoreRandval_>
|
||||
struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
|
||||
{
|
||||
static constexpr bool IsDropout = false;
|
||||
static constexpr bool IsStoreRandval = IsStoreRandval_;
|
||||
|
||||
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
index_t seqlen_qk_start)
|
||||
{
|
||||
(void)randval_dram_block_window_tmp;
|
||||
(void)seqlen_qk_start;
|
||||
|
||||
return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
|
||||
}
|
||||
};
|
||||
|
||||
template <bool IsWG32_, bool IsStoreRandval_>
|
||||
struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
|
||||
{
|
||||
static constexpr bool IsDropout = true;
|
||||
// true: 32*32 warp gemm
|
||||
// false: 16*16 warp gemm
|
||||
static constexpr bool IsWG32 = IsWG32_;
|
||||
static constexpr bool IsStoreRandval = IsStoreRandval_;
|
||||
|
||||
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch,
|
||||
index_t i_head,
|
||||
index_t nheads,
|
||||
unsigned long long seed,
|
||||
unsigned long long offset,
|
||||
float rp_undrop_,
|
||||
uint8_t p_undrop_in_uint8_t_)
|
||||
: ph(seed,
|
||||
offset + (i_batch * nheads + i_head) * get_warp_size() +
|
||||
(IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
|
||||
rp_undrop(rp_undrop_),
|
||||
p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
index_t seqlen_qk_start)
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
|
||||
constexpr index_t kMPerStep = [&]() {
|
||||
if constexpr(MBwdWG16MultiIterCheck)
|
||||
{
|
||||
return MWarp * WG::kM * 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return MWarp * WG::kM;
|
||||
}
|
||||
}();
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
|
||||
auto randval_dram_window = [&]() {
|
||||
if constexpr(IsFwd)
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
|
||||
}
|
||||
}();
|
||||
|
||||
return randval_dram_window;
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t kMPerStep = MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = WG::kN;
|
||||
constexpr index_t kN1 = 8;
|
||||
constexpr index_t kN0 = kNPerStep / kN1;
|
||||
|
||||
constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
|
||||
ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
|
||||
number<kN1>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
|
||||
randval_lds_block_desc_0,
|
||||
ck_tile::make_tuple(
|
||||
make_pass_through_transform(number<kMPerStep>{}),
|
||||
make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
|
||||
ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return randval_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename BlockGemm, bool IsFwd = true>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
|
||||
|
||||
constexpr index_t MIterPerWarp = [&]() {
|
||||
if constexpr(MBwdWG16MultiIterCheck)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
constexpr index_t NIterPerWarp = 1;
|
||||
|
||||
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
// Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
|
||||
// except headdim256.
|
||||
constexpr auto randval_block_inner_part_dstr_encoding = []() {
|
||||
if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
|
||||
std::is_same_v<typename BlockGemm::BDataType, half_t> &&
|
||||
std::is_same_v<typename BlockGemm::CDataType, float>)
|
||||
{
|
||||
if constexpr(IsWG32)
|
||||
return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
|
||||
else
|
||||
return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{};
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(IsWG32)
|
||||
return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
|
||||
else
|
||||
return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{};
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr auto randval_block_part_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
|
||||
randval_block_inner_part_dstr_encoding);
|
||||
|
||||
return make_static_tile_distribution(randval_block_part_dstr_encode);
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = 1;
|
||||
constexpr index_t NIterPerWarp = 1;
|
||||
|
||||
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto randval_block_part_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
|
||||
typename WG::CWarpDstrEncoding{});
|
||||
|
||||
return make_static_tile_distribution(randval_block_part_dstr_encode);
|
||||
}
|
||||
|
||||
template <typename BlockGemm,
|
||||
typename PComputeDataType,
|
||||
typename RandValOutputDataType,
|
||||
typename PComputeWindow,
|
||||
typename RandValDramWindow>
|
||||
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
|
||||
CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
|
||||
const index_t start_m0_idx,
|
||||
const index_t start_n0_idx,
|
||||
PComputeWindow& p_compute,
|
||||
RandValDramWindow& randval_dram_window) const
|
||||
{
|
||||
@@ -305,30 +520,177 @@ struct BlockDropout
|
||||
constexpr index_t kMPerStep = MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
// register distribute
|
||||
auto randval =
|
||||
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
|
||||
static_assert(randval.kThreadElementSpaceSize == 16);
|
||||
// randval tile in LDS
|
||||
auto randval_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
|
||||
|
||||
const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
int block_row_start = (start_m0_idx / WG::kM) + i_m0;
|
||||
int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
|
||||
auto randval_lds_window = make_tile_window(
|
||||
randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
|
||||
|
||||
// register distribute
|
||||
auto randval_dist_generated =
|
||||
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
|
||||
|
||||
auto randval_lds_read_window =
|
||||
make_tile_window(randval_lds_window.get_bottom_tensor_view(),
|
||||
randval_lds_window.get_window_lengths(),
|
||||
randval_lds_window.get_window_origin(),
|
||||
MakeRandValLdsShuffleTileDistribution<BlockGemm>());
|
||||
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
|
||||
int block_col_start = (start_n0_idx / WG::kN) + i_n0;
|
||||
uint2 rowcol = make_uint2(block_row_start, block_col_start);
|
||||
|
||||
// generate random number
|
||||
uint8_t random_uint8_t[16];
|
||||
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
|
||||
|
||||
constexpr auto randval_dist_generated_spans =
|
||||
decltype(randval_dist_generated)::get_distributed_spans();
|
||||
int i_random_idx = 0;
|
||||
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
|
||||
});
|
||||
});
|
||||
// save to LDS
|
||||
store_tile(randval_lds_window, randval_dist_generated);
|
||||
block_sync_lds();
|
||||
// read from LDS to register
|
||||
auto randval = load_tile(randval_lds_read_window);
|
||||
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
|
||||
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
|
||||
constexpr auto p_idx1 =
|
||||
tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
|
||||
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
|
||||
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
|
||||
? p_compute[p_idx] * rp_undrop
|
||||
: PComputeDataType(0);
|
||||
});
|
||||
});
|
||||
// save to Global
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
|
||||
store_tile(randval_dram_window, randval_store);
|
||||
move_tile_window(randval_dram_window, {0, kNPerStep});
|
||||
}
|
||||
});
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
|
||||
}
|
||||
});
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BlockGemm,
|
||||
typename RandValOutputDataType,
|
||||
typename PComputeWindow,
|
||||
typename RandValDramWindow>
|
||||
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
|
||||
const index_t start_n0_idx,
|
||||
PComputeWindow& p_compute,
|
||||
RandValDramWindow& randval_dram_window) const
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
constexpr bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16);
|
||||
constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
|
||||
constexpr index_t kMPerStep = [&]() {
|
||||
if constexpr(MBwdWG16MultiIterCheck)
|
||||
{
|
||||
return MWarp * WG::kM * 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return MWarp * WG::kM;
|
||||
}
|
||||
}();
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
// register distribute
|
||||
auto randval = make_static_distributed_tensor<uint8_t>(
|
||||
MakeRandValTileDistribution<BlockGemm, false>());
|
||||
if constexpr(IsWG32)
|
||||
static_assert(randval.kThreadElementSpaceSize == 16);
|
||||
else
|
||||
static_assert(randval.kThreadElementSpaceSize == 4 ||
|
||||
randval.kThreadElementSpaceSize == 8);
|
||||
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
int block_row_start, block_col_start;
|
||||
if constexpr(IsWG32)
|
||||
{
|
||||
block_row_start = (start_m0_idx / WG::kM) + i_m0;
|
||||
block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
|
||||
}
|
||||
else
|
||||
{
|
||||
block_row_start = start_m0_idx / 32 + i_m0;
|
||||
block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
|
||||
}
|
||||
uint2 rowcol = make_uint2(block_row_start, block_col_start);
|
||||
|
||||
// generate random number
|
||||
uint8_t* random_uint8_t_;
|
||||
if constexpr(MBwdWG16SingleIterCheck)
|
||||
{
|
||||
uint8_t random_uint8_t[4];
|
||||
// m0t0 ~m0t15/m0t32~m0t47: 0
|
||||
// m0t16~m0t31/m0t48~m0t63: 1
|
||||
// m1t0 ~m1t15/m1t32~m1t47: 2
|
||||
// m1t16~m1t31/m1t48~m1t63: 3
|
||||
const index_t start_idx =
|
||||
((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
|
||||
ph.get_random_4x8(
|
||||
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
|
||||
random_uint8_t_ = random_uint8_t;
|
||||
}
|
||||
else if constexpr(MBwdWG16MultiIterCheck)
|
||||
{
|
||||
uint8_t random_uint8_t[8];
|
||||
// t0 ~t15/t32~t47: 0
|
||||
// t16~t31/t48~t63: 1
|
||||
const index_t start_idx = (get_lane_id() >> 4) & 1;
|
||||
ph.get_random_8x8(
|
||||
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
|
||||
random_uint8_t_ = random_uint8_t;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint8_t random_uint8_t[16];
|
||||
ph.get_random_16x8(random_uint8_t,
|
||||
reinterpret_cast<unsigned long long&>(rowcol));
|
||||
random_uint8_t_ = random_uint8_t;
|
||||
}
|
||||
|
||||
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
|
||||
int i_random_idx = 0;
|
||||
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
randval(r_idx) = random_uint8_t[i_random_idx++];
|
||||
constexpr auto p_idx0 =
|
||||
tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
|
||||
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
randval(r_idx) = random_uint8_t_[i_random_idx++];
|
||||
constexpr auto p_idx0 = tile_distributed_index<i_m0 + idx0.impl_.at(0),
|
||||
idx0.impl_.at(1),
|
||||
idx0.impl_.at(2)>{};
|
||||
constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
|
||||
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
|
||||
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
|
||||
@@ -337,19 +699,19 @@ struct BlockDropout
|
||||
});
|
||||
});
|
||||
// save to Global
|
||||
if(is_store_randval)
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
|
||||
store_tile(randval_dram_window, randval_store);
|
||||
move_tile_window(randval_dram_window, {kMPerStep, 0});
|
||||
}
|
||||
});
|
||||
if(is_store_randval)
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
|
||||
}
|
||||
});
|
||||
if(is_store_randval)
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
|
||||
}
|
||||
@@ -358,7 +720,6 @@ struct BlockDropout
|
||||
ck_tile::philox ph;
|
||||
const float rp_undrop;
|
||||
const uint8_t p_undrop_in_uint8_t;
|
||||
const bool is_store_randval;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,54 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockFmhaShape_>
|
||||
struct FmhaBwdTilePartitioner
|
||||
{
|
||||
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
|
||||
|
||||
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
|
||||
{
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
template <ck_tile::index_t kBlockSize>
|
||||
struct FmhaBwdOGradDotOTilePartitioner
|
||||
{
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
|
||||
{
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -86,7 +86,7 @@ struct FmhaFwdKernel
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
@@ -387,7 +387,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
@@ -448,7 +447,6 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
@@ -524,7 +522,7 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
|
||||
@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
|
||||
_SS_(FmhaPipeline::name) +
|
||||
(pn.empty() ? "" : "_" + pn) +
|
||||
(kStoreLSE ? "_lse" : "" ) +
|
||||
(kStoreLSE ? "_lse" : "" ) +
|
||||
(kDoFp8StaticQuant ? "_squant" : "" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
};
|
||||
|
||||
struct GroupModeKargs
|
||||
@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
batch_stride_o};
|
||||
batch_stride_o,
|
||||
batch_stride_lse_acc};
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t nhead_stride_o_acc,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_lse_acc,
|
||||
ck_tile::index_t batch_stride_o_acc,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t split_stride_lse_acc,
|
||||
ck_tile::index_t split_stride_o_acc)
|
||||
{
|
||||
@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
{
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
const long_index_t batch_offset_lse_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
batch_offset_lse_acc = query_start;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
|
||||
@@ -85,7 +85,7 @@ struct FmhaFwdSplitKVKernel
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t nhead_stride_lse_acc;
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
};
|
||||
|
||||
struct GroupModeKargs
|
||||
@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel
|
||||
nhead_stride_v,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel
|
||||
{}, // placeholder for dropout
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v};
|
||||
batch_stride_v,
|
||||
batch_stride_lse_acc};
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse_acc,
|
||||
ck_tile::index_t nhead_stride_o_acc,
|
||||
ck_tile::index_t batch_stride_lse_acc,
|
||||
ck_tile::index_t batch_stride_o_acc,
|
||||
ck_tile::index_t split_stride_lse_acc,
|
||||
ck_tile::index_t split_stride_o_acc,
|
||||
@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel
|
||||
nhead_stride_v,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_randval = 0;
|
||||
const long_index_t batch_offset_lse_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
@@ -522,8 +519,9 @@ struct FmhaFwdSplitKVKernel
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
batch_offset_lse_acc = query_start;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
@@ -564,9 +562,10 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
|
||||
141
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
Normal file
141
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
Normal file
@@ -0,0 +1,141 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
|
||||
struct BlockFmhaBwdConvertQGrad
|
||||
{
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
|
||||
static constexpr index_t kM0 = Problem::kM0;
|
||||
static constexpr index_t kN0 = Problem::kN0;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kQKHeaddim = Problem::kQKHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
|
||||
static constexpr index_t kAlignmentQGradAcc =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
|
||||
|
||||
// Convert only
|
||||
template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
|
||||
QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<AccDataType,
|
||||
remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
|
||||
|
||||
auto dq_acc_dram_window =
|
||||
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_acc_dram_block_window_tmp.get_window_lengths(),
|
||||
dq_acc_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakePostQGradDramTileDistribution<Problem>());
|
||||
|
||||
auto dq_acc = load_tile(dq_acc_dram_window);
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
|
||||
store_tile(dq_dram_block_window_tmp, dq);
|
||||
}
|
||||
|
||||
// Reduce + Convert
|
||||
template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
|
||||
QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
index_t nsplits) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<AccDataType,
|
||||
remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
|
||||
|
||||
auto dq_acc_dram_window =
|
||||
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_acc_dram_block_window_tmp.get_window_lengths(),
|
||||
dq_acc_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakePostQGradAccDramTileDistribution<Problem>());
|
||||
|
||||
auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
|
||||
clear_tile(dq_acc);
|
||||
|
||||
constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();
|
||||
index_t i_total_loops = 0;
|
||||
auto dq_acc_buf = load_tile(dq_acc_dram_window);
|
||||
move_tile_window(dq_acc_dram_window, {1, 0, 0});
|
||||
|
||||
do
|
||||
{
|
||||
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
|
||||
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
|
||||
dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
dq_acc_buf = load_tile(dq_acc_dram_window);
|
||||
move_tile_window(dq_acc_dram_window, {1, 0, 0});
|
||||
|
||||
i_total_loops += 1;
|
||||
} while(i_total_loops < (nsplits - 1));
|
||||
|
||||
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
|
||||
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
|
||||
dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// declare dq
|
||||
constexpr auto dq_converted_dstr =
|
||||
Policy::template MakePostQGradAccDramTileDistribution<Problem>();
|
||||
auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
|
||||
|
||||
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
|
||||
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
|
||||
dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
|
||||
auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
|
||||
dq.get_thread_buffer() = dq_converted.get_thread_buffer();
|
||||
|
||||
store_tile(dq_dram_block_window_tmp, dq);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -4,11 +4,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy>
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
|
||||
struct BlockFmhaBwdOGradDotO
|
||||
{
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// These templates are not used here.
|
||||
using BlockFmhaBwdOGradDotODefaultPolicy =
|
||||
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
|
||||
/* QTLoadOnce_ = */ false,
|
||||
/* KLoadOnce_ = */ false,
|
||||
/* KTLoadOnce_ = */ false,
|
||||
/* VLoadOnce_ = */ false,
|
||||
/* OGradLoadOnce_ = */ false,
|
||||
/* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,782 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
|
||||
using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "kr_ktr_vr";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
// K, HBM ->LDS ->Reg
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
// Early termination
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc&dv_acc are all cleard, return it
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
KDataType* k_lds_ptr =
|
||||
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// V, HBM ->LDS ->Reg
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
VDataType* v_lds_ptr =
|
||||
static_cast<VDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, Reg ->LDS ->Reg
|
||||
auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeShuffledKRegWriteBlockDescriptor<Problem>());
|
||||
|
||||
KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
|
||||
auto shuffled_k_lds_write = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_k_lds_write_window = make_tile_window(
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto kt_lds_read_window =
|
||||
make_tile_window(kt_lds_read,
|
||||
make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKTRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// Pre-Load KV into Registers
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
auto v_block_tile = load_tile(v_dram_window);
|
||||
|
||||
store_tile(k_lds_write_window, k_block_tile);
|
||||
shuffle_tile(shuffled_k_block_tile, k_block_tile);
|
||||
store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
k_reg_tensor = load_tile(k_lds_read_window);
|
||||
block_sync_lds();
|
||||
|
||||
auto kt_reg_tensor = load_tile(kt_lds_read_window);
|
||||
|
||||
store_tile(v_lds_write_window, v_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
block_sync_lds();
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM ->Reg ->LDS
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0},
|
||||
Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGradT<Problem>()));
|
||||
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kK0>{}),
|
||||
q_lds_window.get_window_origin(),
|
||||
Policy::template MakeQRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
Policy::template MakePTRegSliceBlockDescriptor<Problem>());
|
||||
// QT: Reg -> Reg-> LDS
|
||||
auto shuffled_q_block_tile = make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeShuffledQRegWriteBlockDescriptor<Problem>());
|
||||
|
||||
QDataType* qt_lds_ptr =
|
||||
static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
|
||||
auto shuffled_q_lds_write = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_q_lds_write_window = make_tile_window(
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto qt_lds_read_window =
|
||||
make_tile_window(qt_lds_read,
|
||||
make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// dO: HBM ->Reg ->LDS
|
||||
auto do_dram_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0},
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>());
|
||||
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
|
||||
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kK2>{}),
|
||||
do_lds_window.get_window_origin(),
|
||||
Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
|
||||
// dOT: Reg ->Reg ->LDS
|
||||
auto shuffled_do_block_tile = make_static_distributed_tensor<OGradDataType>(
|
||||
Policy::template MakeShuffledOGradRegWriteBlockDescriptor<Problem>());
|
||||
|
||||
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>()));
|
||||
|
||||
auto shuffled_do_lds_write = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_do_lds_write_window = make_tile_window(
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto dot_lds_read_window =
|
||||
make_tile_window(dot_read_lds,
|
||||
make_tuple(number<kVHeaddim>{}, number<kM0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// dS: Reg -> Reg -> LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGradT<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
|
||||
Policy::template GetSmemSizeD<Problem>()));
|
||||
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
auto ds_lds_read_window =
|
||||
make_tile_window(ds_lds_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kK4>{}),
|
||||
ds_lds_window.get_window_origin(),
|
||||
Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
|
||||
// Bias: HBM ->Reg ->Reg ->LDS
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})},
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
BiasDataType* bias_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGradT<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
|
||||
Policy::template GetSmemSizeD<Problem>()));
|
||||
|
||||
auto bias_lds = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto bias_lds_write_window =
|
||||
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
auto bias_s_lds_read_window =
|
||||
make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
|
||||
bias_lds_write_window.get_window_lengths(),
|
||||
bias_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// LSE: HBM -> LDS ->Reg
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
LSEDataType* lse_lds_ptr = static_cast<LSEDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGradT<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>()));
|
||||
|
||||
auto lse_lds = make_tensor_view<address_space_enum::lds>(
|
||||
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
|
||||
|
||||
auto lse_lds_read_window = make_tile_window(
|
||||
lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// D: HBM ->Reg
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
DDataType* d_lds_ptr = static_cast<DDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGradT<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>()));
|
||||
|
||||
auto d_lds = make_tensor_view<address_space_enum::lds>(
|
||||
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
|
||||
|
||||
auto d_lds_read_window = make_tile_window(
|
||||
d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// RandVal: HBM ->Reg
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
// BiasGrad
|
||||
// Reg ->LDS ->Reg ->HBM
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto dbias_dram_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto dbias_lds_read_window =
|
||||
make_tile_window(bias_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
// ----------------------------Loop write out------------------------------//
|
||||
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
using SPBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// Hot loop
|
||||
while(i_total_loops < num_total_loop)
|
||||
{
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
move_tile_window(q_dram_window, {kM0, 0});
|
||||
|
||||
auto lse_block_tile = load_tile(lse_dram_window);
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
|
||||
store_tile(q_lds_window, q_block_tile);
|
||||
shuffle_tile(shuffled_q_block_tile, q_block_tile);
|
||||
store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
|
||||
|
||||
store_tile(lse_lds_write_window, lse_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto q_reg_tensor = load_tile(q_lds_read_window);
|
||||
auto lse = load_tile(lse_lds_read_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto s_acc = SPBlockTileType{};
|
||||
|
||||
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(shuffled_bias_tile, bias_tile);
|
||||
store_tile(bias_lds_write_window, shuffled_bias_tile);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
},
|
||||
s_acc,
|
||||
bias_s_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto p = SPBlockTileType{};
|
||||
constexpr auto p_spans = decltype(p)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
|
||||
}
|
||||
const auto p_gemm = [&]() {
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
p);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(p);
|
||||
}
|
||||
}();
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
auto do_block_tile = load_tile(do_dram_window);
|
||||
move_tile_window(do_dram_window, {kM0, 0});
|
||||
|
||||
auto d_block_tile = load_tile(d_dram_window);
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
|
||||
store_tile(do_lds_window, do_block_tile);
|
||||
shuffle_tile(shuffled_do_block_tile, do_block_tile);
|
||||
store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
|
||||
|
||||
store_tile(d_lds_write_window, d_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto dot_reg_tensor = load_tile(dot_lds_read_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
Policy::template PTFromGemm0CToGemm1A<Problem,
|
||||
decltype(pt_reg_tensor),
|
||||
decltype(p_gemm)>(pt_reg_tensor, p_gemm);
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto do_reg_tensor = load_tile(do_lds_read_window);
|
||||
auto d = load_tile(d_lds_read_window);
|
||||
block_sync_lds();
|
||||
|
||||
auto dp_acc = SPGradBlockTileType{};
|
||||
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
auto ds = SPGradBlockTileType{};
|
||||
constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
|
||||
sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = p[i_j_idx] >= 0;
|
||||
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
|
||||
? (dp_acc[i_j_idx] - d[i_idx])
|
||||
: d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbias = [&]() {
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
ds);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(ds);
|
||||
}
|
||||
}();
|
||||
store_tile(bias_lds_write_window, dbias);
|
||||
block_sync_lds();
|
||||
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
|
||||
auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbias_tile, shuffled_dbias_tile);
|
||||
store_tile(dbias_dram_window, dbias_tile);
|
||||
move_tile_window(dbias_dram_window, {kM0, 0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
auto qt_reg_tensor = load_tile(qt_lds_read_window);
|
||||
block_sync_lds();
|
||||
|
||||
const auto ds_gemm = cast_tile<GemmDataType>(ds);
|
||||
|
||||
Policy::template SGradTFromGemm2CToGemm3A<Problem,
|
||||
decltype(dst_reg_tensor),
|
||||
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
|
||||
|
||||
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto ds_reg_tensor = load_tile(ds_lds_read_window);
|
||||
auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
|
||||
move_tile_window(ds_lds_read_window, {0, kK4});
|
||||
|
||||
// STAGE7 SGrad@K^T Gemm4
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc);
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor_next = load_tile(ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {0, kK4});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
|
||||
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
|
||||
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
|
||||
}
|
||||
});
|
||||
move_tile_window(ds_lds_read_window, {0, -kN0});
|
||||
// QGrad Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
store_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
move_tile_window(dq_dram_window, {kM0, 0});
|
||||
|
||||
i_total_loops += 1;
|
||||
seqlen_q_step += kM0;
|
||||
}
|
||||
|
||||
// Results Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,848 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kQLoadOnce = false;
|
||||
static constexpr bool kQTLoadOnce = false;
|
||||
static constexpr bool kKLoadOnce = true;
|
||||
static constexpr bool kKTLoadOnce = true;
|
||||
static constexpr bool kVLoadOnce = true;
|
||||
static constexpr bool kOGradLoadOnce = false;
|
||||
static constexpr bool kOGradTLoadOnce = false;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "ks_kts_vr";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename QTDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KTDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename OGradTDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KTDramBlockWindowTmp& kt_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
#endif
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QDataType,
|
||||
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType,
|
||||
remove_cvref_t<typename KTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kVHeaddim ==
|
||||
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// QT tile in LDS
|
||||
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto qt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
|
||||
auto qt_lds_window =
|
||||
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// KT tile in LDS
|
||||
KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto kt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor<Problem>());
|
||||
auto kt_lds_window =
|
||||
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// OGrad tile in LDS
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
// OGradT tile in LDS
|
||||
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto dot_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
|
||||
auto dot_lds_window =
|
||||
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
|
||||
|
||||
// SGrad tile in LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// BiasT/BiasGradT tile in LDS, use the same size and layout
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeKT<Problem>()));
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
auto biast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
auto dbiast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
|
||||
|
||||
auto v = load_tile(v_dram_window); // persistent V register tile
|
||||
|
||||
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc&dv_acc are all cleard, return it
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
|
||||
|
||||
auto kt_dram_block_window = kt_dram_block_window_tmp;
|
||||
|
||||
auto kt_dram_window = make_tile_window(
|
||||
kt_dram_block_window.get_bottom_tensor_view(),
|
||||
kt_dram_block_window.get_window_lengths(),
|
||||
kt_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKTDramTileDistribution<Problem>()); // K^T DRAM tile window for
|
||||
// load
|
||||
|
||||
auto kt_block_tile = load_tile(kt_dram_window);
|
||||
|
||||
auto kt_shuffle_tmp = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeShuffledKTRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(kt_shuffle_tmp, kt_block_tile);
|
||||
|
||||
store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS
|
||||
|
||||
auto q_dram_block_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto qt_dram_block_window =
|
||||
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
qt_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto do_dram_block_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto dot_dram_block_window =
|
||||
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dot_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto dq_dram_block_window =
|
||||
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto lse_dram_block_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
auto d_dram_block_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_block_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
auto dbias_dram_block_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto qt_dram_window =
|
||||
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
|
||||
qt_dram_block_window.get_window_lengths(),
|
||||
qt_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQTDramTileDistribution<Problem>());
|
||||
|
||||
auto dot_dram_window =
|
||||
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
|
||||
dot_dram_block_window.get_window_lengths(),
|
||||
dot_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradTDramTileDistribution<Problem>());
|
||||
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window.get_bottom_tensor_view(),
|
||||
lse_dram_block_window.get_window_lengths(),
|
||||
lse_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window.get_bottom_tensor_view(),
|
||||
d_dram_block_window.get_window_lengths(),
|
||||
d_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
|
||||
bias_dram_block_window.get_window_lengths(),
|
||||
bias_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto biast_lds_window =
|
||||
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
|
||||
biast_lds_shuffle_window.get_window_lengths(),
|
||||
biast_lds_shuffle_window.get_window_origin(),
|
||||
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kM0 / kK1;
|
||||
constexpr index_t k2_loops = kVHeaddim / kK2;
|
||||
constexpr index_t k3_loops = kM0 / kK3;
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
do
|
||||
{
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
|
||||
// load
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram_block_window.get_bottom_tensor_view(),
|
||||
do_dram_block_window.get_window_lengths(),
|
||||
do_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
|
||||
// window for load
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto st_acc = SPTBlockTileType{};
|
||||
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
{
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write 0
|
||||
q_block_tile = load_tile(q_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kN0, (i_k0 + 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
store_tile(q_lds_window,
|
||||
q_block_tile); // LDS write i + 1
|
||||
q_block_tile = load_tile(q_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kN0, (k0_loops - 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(q_lds_window, q_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kN0, k0_loops * kK0>{}));
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
|
||||
block_sync_lds();
|
||||
auto biast_tile = load_tile(biast_lds_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = raw_scale * x + type_convert<AccDataType>(y);
|
||||
#else
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
#endif
|
||||
},
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
|
||||
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
st_acc(i_j_idx) *= raw_scale;
|
||||
#else
|
||||
st_acc(i_j_idx) *= scale;
|
||||
#endif
|
||||
position_encoding.update(st_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto pt = SPTBlockTileType{};
|
||||
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
|
||||
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
#else
|
||||
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
|
||||
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
|
||||
}
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
const auto pt_gemm = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
pt);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(pt);
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(pt_gemm,
|
||||
sequence<i_k1 * kK1, 0>{},
|
||||
sequence<(i_k1 + 1) * kK1, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(dot_shuffle_tmp, dot);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(
|
||||
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dpt_acc = SPGradTBlockTileType{};
|
||||
|
||||
{
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
clear_tile(dpt_acc); // Initialize PGrad^T
|
||||
|
||||
store_tile(do_lds_window, do_block_tile); // LDS write 0
|
||||
do_block_tile = load_tile(do_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(k2_loops > 2)
|
||||
{
|
||||
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(
|
||||
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
store_tile(do_lds_window,
|
||||
do_block_tile); // LDS write i + 1
|
||||
do_block_tile = load_tile(do_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 2) * kK2>{},
|
||||
sequence<kN0, (k2_loops - 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(do_lds_window, do_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 1) * kK2>{},
|
||||
sequence<kN0, k2_loops * kK2>{}));
|
||||
}
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
const auto d = load_tile(d_dram_window);
|
||||
|
||||
auto dst = SPGradTBlockTileType{};
|
||||
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
|
||||
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = pt[i_j_idx] >= 0;
|
||||
dst(i_j_idx) =
|
||||
pt[i_j_idx] *
|
||||
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbiast = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(dst);
|
||||
}
|
||||
}();
|
||||
store_tile(biast_lds_shuffle_window, dbiast);
|
||||
block_sync_lds();
|
||||
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
|
||||
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_block_window, {kM0, 0});
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
if constexpr(k3_loops > 1)
|
||||
{
|
||||
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
|
||||
const auto qt = load_tile(qt_dram_window); // load next Q^T
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(dst_gemm,
|
||||
sequence<i_k3 * kK3, 0>{},
|
||||
sequence<(i_k3 + 1) * kK3, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(qt_shuffle_tmp, qt);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(
|
||||
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc); // Initialize QGrad
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
gemm_4(dq_acc,
|
||||
get_slice_tile(ds_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kM0, (i_k4 + 1) * kK4>{}),
|
||||
get_slice_tile(kt_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
|
||||
});
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
update_tile(dq_dram_block_window, dq);
|
||||
|
||||
// move tile windows
|
||||
move_tile_window(q_dram_block_window, {kM0, 0});
|
||||
move_tile_window(dq_dram_block_window, {kM0, 0});
|
||||
move_tile_window(do_dram_block_window, {kM0, 0});
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// KGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
// VGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,20 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for the KS/KTS/VR backward pipeline variant:
// V is held load-once in registers, and both K and K^T are held
// load-once in LDS. Q, Q^T, OGrad and OGrad^T are not load-once
// (they are streamed from global memory as the pipeline iterates).
// This pipeline is v located in regs, k & k^t located in lds.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
    BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
                                      /* QTLoadOnce_ = */ false,
                                      /* KLoadOnce_ = */ true,
                                      /* KTLoadOnce_ = */ true,
                                      /* VLoadOnce_ = */ true,
                                      /* OGradLoadOnce_ = */ false,
                                      /* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,821 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kQLoadOnce = false;
|
||||
static constexpr bool kQTLoadOnce = false;
|
||||
static constexpr bool kKLoadOnce = true;
|
||||
static constexpr bool kKTLoadOnce = false;
|
||||
static constexpr bool kVLoadOnce = true;
|
||||
static constexpr bool kOGradLoadOnce = false;
|
||||
static constexpr bool kOGradTLoadOnce = false;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "ks_vr";
|
||||
|
||||
    /// @brief Compile-time LDS (shared-memory) byte requirement of this pipeline.
    /// Forwards to the policy, which computes the size for the given Problem.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename QTDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KTDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename OGradTDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
#endif
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QDataType,
|
||||
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kVHeaddim ==
|
||||
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// QT tile in LDS
|
||||
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto qt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
|
||||
auto qt_lds_window =
|
||||
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// KT tile in LDS
|
||||
auto kt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
|
||||
auto kt_lds_window =
|
||||
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// OGrad tile in LDS
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
|
||||
// OGradT tile in LDS
|
||||
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto dot_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
|
||||
auto dot_lds_window =
|
||||
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
|
||||
|
||||
// SGrad tile in LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// BiasT/BiasGradT tile in LDS, use the same size and layout
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
auto biast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
auto dbiast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
|
||||
|
||||
auto v = load_tile(v_dram_window); // persistent V register tile
|
||||
|
||||
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc&dv_acc are all cleard, return it
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
|
||||
|
||||
auto q_dram_block_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto qt_dram_block_window =
|
||||
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
qt_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto do_dram_block_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto dot_dram_block_window =
|
||||
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dot_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_q_start});
|
||||
|
||||
auto dq_dram_block_window =
|
||||
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto lse_dram_block_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
auto d_dram_block_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_block_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
auto dbias_dram_block_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto qt_dram_window =
|
||||
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
|
||||
qt_dram_block_window.get_window_lengths(),
|
||||
qt_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQTDramTileDistribution<Problem>());
|
||||
|
||||
auto dot_dram_window =
|
||||
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
|
||||
dot_dram_block_window.get_window_lengths(),
|
||||
dot_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradTDramTileDistribution<Problem>());
|
||||
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window.get_bottom_tensor_view(),
|
||||
lse_dram_block_window.get_window_lengths(),
|
||||
lse_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window.get_bottom_tensor_view(),
|
||||
d_dram_block_window.get_window_lengths(),
|
||||
d_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
|
||||
bias_dram_block_window.get_window_lengths(),
|
||||
bias_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto biast_lds_window =
|
||||
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
|
||||
biast_lds_shuffle_window.get_window_lengths(),
|
||||
biast_lds_shuffle_window.get_window_origin(),
|
||||
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kM0 / kK1;
|
||||
constexpr index_t k2_loops = kVHeaddim / kK2;
|
||||
constexpr index_t k3_loops = kM0 / kK3;
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
do
|
||||
{
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
|
||||
// load
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram_block_window.get_bottom_tensor_view(),
|
||||
do_dram_block_window.get_window_lengths(),
|
||||
do_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
|
||||
// window for load
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto st_acc = SPTBlockTileType{};
|
||||
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
{
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write 0
|
||||
q_block_tile = load_tile(q_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kN0, (i_k0 + 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(q_dram_window, {0, kK0});
|
||||
|
||||
store_tile(q_lds_window,
|
||||
q_block_tile); // LDS write i + 1
|
||||
q_block_tile = load_tile(q_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kN0, (k0_loops - 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(q_lds_window, q_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(st_acc,
|
||||
q_lds_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kN0, k0_loops * kK0>{}));
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
|
||||
block_sync_lds();
|
||||
auto biast_tile = load_tile(biast_lds_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = raw_scale * x + type_convert<AccDataType>(y);
|
||||
#else
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
#endif
|
||||
},
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
|
||||
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
st_acc(i_j_idx) *= raw_scale;
|
||||
#else
|
||||
st_acc(i_j_idx) *= scale;
|
||||
#endif
|
||||
position_encoding.update(st_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto pt = SPTBlockTileType{};
|
||||
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
|
||||
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
#else
|
||||
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
|
||||
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
|
||||
}
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
const auto pt_gemm = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
pt);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(pt);
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(pt_gemm,
|
||||
sequence<i_k1 * kK1, 0>{},
|
||||
sequence<(i_k1 + 1) * kK1, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(dot_shuffle_tmp, dot);
|
||||
store_tile(dot_lds_window,
|
||||
dot_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(dot_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(
|
||||
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
|
||||
dot_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dpt_acc = SPGradTBlockTileType{};
|
||||
|
||||
{
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
clear_tile(dpt_acc); // Initialize PGrad^T
|
||||
|
||||
store_tile(do_lds_window, do_block_tile); // LDS write 0
|
||||
do_block_tile = load_tile(do_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(k2_loops > 2)
|
||||
{
|
||||
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(
|
||||
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
move_tile_window(do_dram_window, {0, kK2});
|
||||
|
||||
store_tile(do_lds_window,
|
||||
do_block_tile); // LDS write i + 1
|
||||
do_block_tile = load_tile(do_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 2) * kK2>{},
|
||||
sequence<kN0, (k2_loops - 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(do_lds_window, do_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_2(dpt_acc,
|
||||
do_lds_window,
|
||||
get_slice_tile(v,
|
||||
sequence<0, (k2_loops - 1) * kK2>{},
|
||||
sequence<kN0, k2_loops * kK2>{}));
|
||||
}
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
const auto d = load_tile(d_dram_window);
|
||||
|
||||
auto dst = SPGradTBlockTileType{};
|
||||
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
|
||||
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = pt[i_j_idx] >= 0;
|
||||
dst(i_j_idx) =
|
||||
pt[i_j_idx] *
|
||||
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbiast = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(dst);
|
||||
}
|
||||
}();
|
||||
store_tile(biast_lds_shuffle_window, dbiast);
|
||||
block_sync_lds();
|
||||
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
|
||||
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_block_window, {kM0, 0});
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
|
||||
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
|
||||
block_sync_lds();
|
||||
{
|
||||
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
}
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
if constexpr(k3_loops > 1)
|
||||
{
|
||||
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
|
||||
const auto qt = load_tile(qt_dram_window); // load next Q^T
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(dst_gemm,
|
||||
sequence<i_k3 * kK3, 0>{},
|
||||
sequence<(i_k3 + 1) * kK3, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
shuffle_tile(qt_shuffle_tmp, qt);
|
||||
store_tile(qt_lds_window,
|
||||
qt_shuffle_tmp); // store the prefetch
|
||||
|
||||
move_tile_window(qt_dram_window, {0, kK3});
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(
|
||||
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
|
||||
qt_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc); // Initialize QGrad
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
gemm_4(dq_acc,
|
||||
get_slice_tile(ds_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kM0, (i_k4 + 1) * kK4>{}),
|
||||
get_slice_tile(kt_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
|
||||
});
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
update_tile(dq_dram_block_window, dq);
|
||||
|
||||
// move tile windows
|
||||
move_tile_window(q_dram_block_window, {kM0, 0});
|
||||
move_tile_window(dq_dram_block_window, {kM0, 0});
|
||||
move_tile_window(do_dram_block_window, {kM0, 0});
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// KGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
// VGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,20 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is v located in regs, k located in lds.
|
||||
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
|
||||
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
|
||||
/* QTLoadOnce_ = */ false,
|
||||
/* KLoadOnce_ = */ true,
|
||||
/* KTLoadOnce_ = */ false,
|
||||
/* VLoadOnce_ = */ true,
|
||||
/* OGradLoadOnce_ = */ false,
|
||||
/* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,692 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kQLoadOnce = true;
|
||||
static constexpr bool kQTLoadOnce = false;
|
||||
static constexpr bool kKLoadOnce = true;
|
||||
static constexpr bool kKTLoadOnce = false;
|
||||
static constexpr bool kVLoadOnce = true;
|
||||
static constexpr bool kOGradLoadOnce = true;
|
||||
static constexpr bool kOGradTLoadOnce = false;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad =
|
||||
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
|
||||
static constexpr const char* name = "qs_ks_vr_dos";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename QTDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename KTDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename OGradTDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
#endif
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// QT tile in LDS
|
||||
auto qt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT<Problem>());
|
||||
auto qt_lds_window =
|
||||
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kM0>{}), {0, 0});
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
// KT tile in LDS
|
||||
auto kt_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<KDataType*>(smem_ptr),
|
||||
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
|
||||
auto kt_lds_window =
|
||||
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// OGrad tile in LDS
|
||||
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>()));
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
// OGradT tile in LDS
|
||||
auto dot_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT<Problem>());
|
||||
auto dot_lds_window =
|
||||
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kM0>{}), {0, 0});
|
||||
|
||||
// SGrad tile in LDS
|
||||
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>()));
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// BiasT/BiasGradT tile in LDS, use the same size and layout
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>()));
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
auto biast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
auto dbiast_lds_shuffle_window =
|
||||
make_tile_window(biast_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
|
||||
|
||||
auto v = load_tile(v_dram_window); // persistent V register tile
|
||||
|
||||
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc&dv_acc are all cleard, return it
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS
|
||||
|
||||
auto q_dram_block_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto do_dram_block_window =
|
||||
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto dq_dram_block_window =
|
||||
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
auto lse_dram_block_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
auto d_dram_block_window =
|
||||
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_block_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
auto dbias_dram_block_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window.get_bottom_tensor_view(),
|
||||
lse_dram_block_window.get_window_lengths(),
|
||||
lse_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window.get_bottom_tensor_view(),
|
||||
d_dram_block_window.get_window_lengths(),
|
||||
d_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
|
||||
bias_dram_block_window.get_window_lengths(),
|
||||
bias_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto biast_lds_window =
|
||||
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
|
||||
biast_lds_shuffle_window.get_window_lengths(),
|
||||
biast_lds_shuffle_window.get_window_origin(),
|
||||
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kM0 / kK1;
|
||||
constexpr index_t k2_loops = kVHeaddim / kK2;
|
||||
constexpr index_t k3_loops = kM0 / kK3;
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
do
|
||||
{
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
|
||||
// load
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram_block_window.get_bottom_tensor_view(),
|
||||
do_dram_block_window.get_window_lengths(),
|
||||
do_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
|
||||
// window for load
|
||||
|
||||
// STAGE 1, Q@K Gemm0
|
||||
auto st_acc = SPTBlockTileType{};
|
||||
|
||||
auto q_block_tile = load_tile(q_dram_window);
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 1)
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
get_slice_tile(q_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kN0, (i_k0 + 1) * kK0>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
}
|
||||
|
||||
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(st_acc,
|
||||
get_slice_tile(q_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kN0, k0_loops * kK0>{}));
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
|
||||
block_sync_lds();
|
||||
auto biast_tile = load_tile(biast_lds_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = raw_scale * x + type_convert<AccDataType>(y);
|
||||
#else
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
#endif
|
||||
},
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
|
||||
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
st_acc(i_j_idx) *= raw_scale;
|
||||
#else
|
||||
st_acc(i_j_idx) *= scale;
|
||||
#endif
|
||||
position_encoding.update(st_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_lse;
|
||||
}
|
||||
};
|
||||
|
||||
auto pt = SPTBlockTileType{};
|
||||
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
|
||||
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
else
|
||||
{
|
||||
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
#else
|
||||
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
|
||||
}
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
block_sync_lds();
|
||||
store_tile(do_lds_window, do_block_tile); // store the prefetch
|
||||
|
||||
const auto pt_gemm = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
|
||||
pt);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(pt);
|
||||
}
|
||||
}();
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
block_sync_lds();
|
||||
gemm_1(dv_acc,
|
||||
get_slice_tile(
|
||||
pt_gemm, sequence<i_k1 * kK1, 0>{}, sequence<(i_k1 + 1) * kK1, kN0>{}),
|
||||
get_slice_tile(dot_lds_window,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kVHeaddim, (i_k1 + 1) * kK1>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dpt_acc = SPGradTBlockTileType{};
|
||||
clear_tile(dpt_acc); // Initialize PGrad^T
|
||||
|
||||
static_for<0, k2_loops, 1>{}([&](auto i_k2) {
|
||||
block_sync_lds();
|
||||
gemm_2(dpt_acc,
|
||||
get_slice_tile(do_lds_window,
|
||||
sequence<0, i_k2 * kK2>{},
|
||||
sequence<kM0, (i_k2 + 1) * kK2>{}),
|
||||
get_slice_tile(
|
||||
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
const auto d = load_tile(d_dram_window);
|
||||
|
||||
auto dst = SPGradTBlockTileType{};
|
||||
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
|
||||
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = pt[i_j_idx] >= 0;
|
||||
dst(i_j_idx) =
|
||||
pt[i_j_idx] *
|
||||
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbiast = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(dst);
|
||||
}
|
||||
}();
|
||||
store_tile(biast_lds_shuffle_window, dbiast);
|
||||
block_sync_lds();
|
||||
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
|
||||
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_block_window, {kM0, 0});
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
block_sync_lds();
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
static_for<0, k3_loops, 1>{}([&](auto i_k3) {
|
||||
block_sync_lds();
|
||||
gemm_3(dk_acc,
|
||||
get_slice_tile(
|
||||
dst_gemm, sequence<i_k3 * kK3, 0>{}, sequence<(i_k3 + 1) * kK3, kN0>{}),
|
||||
get_slice_tile(qt_lds_window,
|
||||
sequence<0, i_k3 * kK3>{},
|
||||
sequence<kQKHeaddim, (i_k3 + 1) * kK3>{}));
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc); // Initialize QGrad
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
gemm_4(dq_acc,
|
||||
get_slice_tile(ds_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kM0, (i_k4 + 1) * kK4>{}),
|
||||
get_slice_tile(kt_lds_window,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
|
||||
});
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
update_tile(dq_dram_block_window, dq);
|
||||
|
||||
// move tile windows
|
||||
move_tile_window(q_dram_block_window, {kM0, 0});
|
||||
move_tile_window(dq_dram_block_window, {kM0, 0});
|
||||
move_tile_window(do_dram_block_window, {kM0, 0});
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// KGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
// VGrad Scale
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
|
||||
return ck_tile::make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,20 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is v located in regs, q & k & do located in lds.
|
||||
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
|
||||
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
|
||||
/* QTLoadOnce_ = */ false,
|
||||
/* KLoadOnce_ = */ true,
|
||||
/* KTLoadOnce_ = */ false,
|
||||
/* VLoadOnce_ = */ true,
|
||||
/* OGradLoadOnce_ = */ true,
|
||||
/* OGradTLoadOnce_ = */ false>;
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,9 +8,8 @@ namespace ck_tile {
|
||||
// This class is used for codegen pattern matching
|
||||
enum class BlockFmhaBwdPipelineEnum
|
||||
{
|
||||
KSKTSVR = 0,
|
||||
QSKSVROGradS,
|
||||
KSVR,
|
||||
KRKTRVR_IGLP = 0,
|
||||
KRKTRVR,
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -24,7 +24,9 @@ template <typename QDataType_,
|
||||
typename BiasGradDataType_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
bool kIsDeterministic_,
|
||||
typename FmhaMask_,
|
||||
typename FmhaDropout_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaBwdPipelineProblem
|
||||
{
|
||||
@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
|
||||
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using FmhaDropout = remove_cvref_t<FmhaDropout_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
@@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
template <typename AccDataType_,
|
||||
typename QGradDataType_,
|
||||
index_t kBlockSize_,
|
||||
index_t kM0_,
|
||||
index_t kN0_,
|
||||
index_t kQKHeaddim_,
|
||||
bool kIsGroupMode_,
|
||||
bool kIsDeterministic_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaBwdConvertQGradPipelineProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using QGradDataType = remove_cvref_t<QGradDataType_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
|
||||
"kBlockSize should be divisible by get_warp_size()");
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kM0 = kM0_;
|
||||
static constexpr index_t kN0 = kN0_;
|
||||
static constexpr index_t kQKHeaddim = kQKHeaddim_;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -86,4 +86,14 @@ struct TileFmhaBwdOGradDotOTraits
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
|
||||
struct TileFmhaBwdConvertQGradTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -5,6 +5,9 @@
|
||||
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
|
||||
202
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
Normal file
202
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
Normal file
@@ -0,0 +1,202 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block distributed tensor
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
|
||||
struct BlockGemmARegBRegCRegV1
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
// M->N Warp
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
// check ABC-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"A distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"B distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"C distribution is wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using BWarpDstr = typename WG::BWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using BWarpTensor = typename WG::BWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Build an empty, statically-distributed C accumulator tile whose
// distribution matches the one produced by the block-GEMM loop: an
// (MIterPerWarp, MWarp) x (NIterPerWarp, NWarp) outer tiling with the
// per-warp C distribution of the selected WarpGemm embedded inside.
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
{
    constexpr index_t kMBlock = BlockGemmShape::kM;
    constexpr index_t kNBlock = BlockGemmShape::kN;

    constexpr auto warp_cfg = Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using WG = remove_cvref_t<decltype(warp_cfg.template at<0>())>;

    constexpr index_t kMWarp = warp_cfg.template at<1>();
    constexpr index_t kNWarp = warp_cfg.template at<2>();

    // number of warp-gemm tiles each warp covers along M / N
    constexpr index_t kMIter = kMBlock / (kMWarp * WG::kM);
    constexpr index_t kNIter = kNBlock / (kNWarp * WG::kN);

    // block-level (outer) distribution of the C tile over warps
    constexpr auto outer_encoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<kMIter, kMWarp>, sequence<kNIter, kNWarp>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 1>>,
        sequence<1, 2>,
        sequence<0, 0>>{};

    // embed the per-warp C distribution into the block-level encoding
    constexpr auto block_encoding = detail::make_embed_tile_distribution_encoding(
        outer_encoding, typename WG::CWarpDstrEncoding{});

    return make_static_distributed_tensor<CDataType>(
        make_static_tile_distribution(block_encoding));
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensor, typename BBlockTensor>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Custom policy for BlockGemmARegBRegCRegV1: the user explicitly pins the
// element types, the per-block warp layout, and the concrete WarpGemm,
// instead of relying on the default type-based dispatch.
template <typename AType_,
          typename BType_,
          typename CType_,
          typename BlockWarps_,
          typename WarpGemm_>
struct BlockGemmARegBRegCRegV1CustomPolicy
{
    using AType = remove_cvref_t<AType_>;
    using BType = remove_cvref_t<BType_>;
    using CType = remove_cvref_t<CType_>;

    // sequence-like type holding the warp counts along M, N, K
    using BlockWarps = remove_cvref_t<BlockWarps_>;

    static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
    static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
    // K-dim warp count is exposed for completeness; GetWarpGemmMWarpNWarp
    // below only reports the M and N warp counts.
    static constexpr index_t kKWarps = BlockWarps::at(number<2>{});

    using WarpGemm = remove_cvref_t<WarpGemm_>;

    // Returns (warp-gemm instance, #warps along M, #warps along N).
    // Problem is unused here but kept so the custom and default policies
    // share the same interface.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        return make_tuple(WarpGemm{}, kMWarps, kNWarps);
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmARegBRegCRegV1
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct BlockGemmARegBRegCRegV1DefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
// KPerBlock == BlockGemmShape::kK,
|
||||
// "wrong!");
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
|
||||
@@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
// constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
// constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
// KPerBlock == BlockGemmShape::kK,
|
||||
// "wrong!");
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
|
||||
@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 =
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
|
||||
|
||||
@@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 =
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
|
||||
|
||||
|
||||
@@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
@@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
|
||||
// c = a * b
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_a>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
// c += a * b
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
|
||||
@@ -15,7 +15,8 @@ template <typename AType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC>
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false>
|
||||
struct WarpGemmMfmaDispatcher;
|
||||
|
||||
// clang-format off
|
||||
@@ -29,6 +30,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
|
||||
// bf16
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
@@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
|
||||
// fp8
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
@@ -58,8 +65,15 @@ template <typename AType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC>
|
||||
using WarpGemmMfmaDispatcher = typename impl::
|
||||
WarpGemmMfmaDispatcher<AType, BType, CType, MPerWave, NPerWave, KPerWave, TransposeC>::Type;
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false>
|
||||
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
BType,
|
||||
CType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user