mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
[CK_TILE] FA bwd kernels optimization (#1397)
* tmp save
* fix batch deterministic bugs
* fix group deterministic bugs
* codegen update
* reorder files
* bias support
* hd256 bias support
* bwd smoke test update
* simplify convert dq
* fix hd256 dropout scratch
* do{}while() -> while(){}
* comments
* remove FmhaBwdTilePartitioner
* save clear_tile
* refactor dropout
* code cleanup
* code cleanup
* comments
* fix epilogue problem
* fix fwd dropout
* group convert_dq opt
* fix dq alignment
* Do not store storerandval in bwd for flash attention integration
* fix hd32 error and boost performance
* revert
* Remove duplicated WarpGemm definitions in the policy file
* dropout patch for mrepeat 16*16
* code sync up
* dq_acc stride
* dq_acc stride stuff
* codegen update
* fwd dropout revert
* fix hd128 scratches and boost performance
* receipt 3 for simplified smoke test
* more strides for fa integration
* fix hd64 scratches and boost performance
* non-iglp pipeline for headdim padding cases
* dpad same as dvpad for flash attention integration
* unpadded lse&d for group mode
* Support unpad layout for group lse
* Support unpad lse layout for splitkv
* Fix stride for splitkv kernel
* fix unpadded lse issue in fwd splitkv
* comment
* solve lds read&write conflicts
* rename
* bias rename
* tile index revert
---------
Co-authored-by: danyao12 <danyao12>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>
[ROCm/composable_kernel commit: 79a5d9c10c]
This commit is contained in:
@@ -6,7 +6,7 @@ execute_process(
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
|
||||
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3
|
||||
)
|
||||
|
||||
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
|
||||
@@ -23,7 +23,7 @@ add_custom_command(
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_BWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3
|
||||
)
|
||||
|
||||
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
@@ -55,11 +55,10 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
|
||||
# ... because they are auto-generated
|
||||
if(FMHA_FWD_FAST_EXP2)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
|
||||
|
||||
# Allow comparing floating points directly in order to check sentinel values
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
|
||||
|
||||
@@ -66,6 +66,22 @@ BIAS_CHECK_MAP = {
|
||||
"alibi" : "bias_enum::alibi"
|
||||
}
|
||||
|
||||
DROPOUT_MAP = {
|
||||
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
|
||||
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
|
||||
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
|
||||
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
|
||||
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
|
||||
}
|
||||
|
||||
DROPOUT_CHECK_MAP = {
|
||||
"no" : "t.has_dropout == false",
|
||||
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
|
||||
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
|
||||
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
|
||||
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
|
||||
}
|
||||
|
||||
MODE_MAP = {
|
||||
"batch" : "false",
|
||||
"group" : "true"
|
||||
|
||||
@@ -14,15 +14,13 @@ from codegen.cpp_symbol_map import *
|
||||
|
||||
|
||||
BWD_DQDKDV_PIPELINE_MAP = {
|
||||
"ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR",
|
||||
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
|
||||
"ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR",
|
||||
"kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP",
|
||||
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR",
|
||||
}
|
||||
|
||||
BWD_DQDKDV_PIPELINE_ENUM_MAP = {
|
||||
"ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR",
|
||||
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS",
|
||||
"ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR",
|
||||
"kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP",
|
||||
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR",
|
||||
}
|
||||
|
||||
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
@@ -34,39 +32,42 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
FMHA_BWD_DQ_DK_DV_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
|
||||
using fmha_block_tile_{F_idx} = ck_tile::
|
||||
sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
|
||||
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
|
||||
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
|
||||
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
|
||||
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
|
||||
using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
|
||||
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
|
||||
fmha_block_warps0_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps0_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps2_{F_idx},
|
||||
fmha_warp_tile_{F_idx}>;
|
||||
fmha_block_warps0_{F_idx},
|
||||
fmha_warp_tile0_{F_idx},
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile1_{F_idx},
|
||||
fmha_block_warps0_{F_idx},
|
||||
fmha_warp_tile0_{F_idx},
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile1_{F_idx},
|
||||
fmha_block_warps2_{F_idx},
|
||||
fmha_warp_tile0_{F_idx}>;
|
||||
|
||||
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
false,
|
||||
{F_dropout},
|
||||
false,
|
||||
{F_occupancy}>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{F_occupancy}>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
using fmha_dropout_{F_idx} = {F_dropout};
|
||||
|
||||
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
@@ -86,55 +87,72 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
|
||||
fmha_bwd_shape_{F_idx},
|
||||
{F_mode},
|
||||
{F_deterministic},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_dropout_{F_idx},
|
||||
fmha_bwd_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_bwd_pipeline_problem_{F_idx}>;
|
||||
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<fmha_bwd_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
|
||||
false, false>>;
|
||||
using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
|
||||
{F_skpad},
|
||||
{F_dpad}>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
|
||||
false, false>>;
|
||||
using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
|
||||
{F_skpad},
|
||||
{F_dvpad}>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>,
|
||||
fmha_bwd_pipeline_{F_idx},
|
||||
fmha_bwd_dk_epilogue_{F_idx},
|
||||
fmha_bwd_dv_epilogue_{F_idx}>;
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
|
||||
fmha_bwd_dk_epilogue_{F_idx},
|
||||
fmha_bwd_dv_epilogue_{F_idx}>;
|
||||
|
||||
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
|
||||
{F_dtype},
|
||||
{F_mode},
|
||||
{F_pipeline_enum},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_dropout_{F_idx},
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_deterministic}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template<>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
@@ -146,14 +164,15 @@ FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
|
||||
FMHA_BWD_API="""
|
||||
#include <iostream>
|
||||
|
||||
template<typename dot_do_o_trait_, typename dq_dk_dv_trait_>
|
||||
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
|
||||
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
|
||||
@@ -173,38 +192,36 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_>(s, a);
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FmhaBwdDQDKDVApiTrait:
|
||||
pipeline : str
|
||||
pipeline : str
|
||||
# sync with fmha_bwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along k seqlen
|
||||
bhdq : int # q head_dim
|
||||
bhdv : int # v head_dim
|
||||
mask : str
|
||||
bias : str
|
||||
dbias : str
|
||||
dropout : str
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along k seqlen
|
||||
bhdq : int # q head_dim
|
||||
bhdv : int # v head_dim
|
||||
mask : str
|
||||
bias : str
|
||||
dbias : str
|
||||
dropout : str
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
deterministic : str
|
||||
|
||||
def scheck(self, spad1 : str) -> str:
|
||||
if self.mode == 'group':
|
||||
@@ -212,9 +229,9 @@ class FmhaBwdDQDKDVApiTrait:
|
||||
elif self.spad == 't' and spad1 == 't':
|
||||
return f'a.seqlen_q % {self.bm0} != 0'
|
||||
elif self.spad == 'f' and spad1 == 't':
|
||||
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize
|
||||
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
|
||||
else: # self.skpad == 'f' and skpad1 == 'f'
|
||||
return f'a.seqlen_q % 256 == 0' # BlockSize
|
||||
return f'a.seqlen_q % 64 == 0'
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
@@ -256,16 +273,19 @@ class FmhaBwdApiPool:
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
|
||||
traits=self.dq_dk_dv_pool[dtype][hdim]
|
||||
hdim_int = int(hdim)
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
for spad1 in ["t", "f"]:
|
||||
if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")):
|
||||
if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
|
||||
continue
|
||||
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
|
||||
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
|
||||
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad])
|
||||
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_deterministic=BOOL_MAP[trait.deterministic])
|
||||
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
@@ -295,81 +315,89 @@ class FmhaBwdDQDKDVTileSize:
|
||||
F_bhdv : int # v head_dim
|
||||
F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2
|
||||
F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2
|
||||
F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2
|
||||
F_rk0 : int # number of warps along headdim_qk/v (not used) in gemm0/gemm2
|
||||
F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3
|
||||
F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3
|
||||
F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3
|
||||
F_rm2 : int # number of warps along k seqlen (block warps) in gemm4
|
||||
F_rn2 : int # number of warps along q seqlen (block warps) in gemm4
|
||||
F_rk2 : int # number of warps along gemm-k (not used) in gemm4
|
||||
F_wm : int # warp size along m (warp size)
|
||||
F_wn : int # warp size along n
|
||||
F_wk : int # warp size along k
|
||||
F_rn1 : int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3
|
||||
F_rk1 : int # number of warps along q seqlen (not used) in gemm1/gemm3
|
||||
F_rm2 : int # number of warps along q seqlen (block warps) in gemm4
|
||||
F_rn2 : int # number of warps along headdim_qk (block warps) in gemm4
|
||||
F_rk2 : int # number of warps along k seqlen (not used) in gemm4
|
||||
F_wm0 : int # warp size along m in gemm0/gemm2/gemm4
|
||||
F_wn0 : int # warp size along n in gemm0/gemm2/gemm4
|
||||
F_wk0 : int # warp size along k in gemm0/gemm2/gemm4
|
||||
F_wm1 : int # warp size along m in gemm1/gemm3
|
||||
F_wn1 : int # warp size along n in gemm1/gemm3
|
||||
F_wk1 : int # warp size along k in gemm1/gemm3
|
||||
F_occupancy : int # occupancy
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
|
||||
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}"
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
|
||||
|
||||
@dataclass
|
||||
class FmhaBwdDQDKDVKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_tile : FmhaBwdDQDKDVTileSize
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_bias : str #
|
||||
F_dbias : str #
|
||||
F_dropout : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_pipeline : str
|
||||
mask_impl : str
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_tile : FmhaBwdDQDKDVTileSize
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_bias : str #
|
||||
F_dbias : str #
|
||||
F_dropout : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_deterministic : str #
|
||||
F_pipeline : str #
|
||||
mask_impl : str #
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_BWD_KERNEL_HEADER + \
|
||||
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk2 = self.F_tile.F_bk2,
|
||||
F_bk3 = self.F_tile.F_bk3,
|
||||
F_bk4 = self.F_tile.F_bk4,
|
||||
F_bhdq = self.F_tile.F_bhdq,
|
||||
F_bhdv = self.F_tile.F_bhdv,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_rm2 = self.F_tile.F_rm2,
|
||||
F_rn2 = self.F_tile.F_rn2,
|
||||
F_rk2 = self.F_tile.F_rk2,
|
||||
F_wm = self.F_tile.F_wm,
|
||||
F_wn = self.F_tile.F_wn,
|
||||
F_wk = self.F_tile.F_wk,
|
||||
F_spad = BOOL_MAP[self.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_dvpad],
|
||||
F_bias = BIAS_MAP[self.F_bias],
|
||||
F_dbias = BOOL_MAP[self.F_dbias],
|
||||
F_dropout = BOOL_MAP[self.F_dropout],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk2 = self.F_tile.F_bk2,
|
||||
F_bk3 = self.F_tile.F_bk3,
|
||||
F_bk4 = self.F_tile.F_bk4,
|
||||
F_bhdq = self.F_tile.F_bhdq,
|
||||
F_bhdv = self.F_tile.F_bhdv,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_rm2 = self.F_tile.F_rm2,
|
||||
F_rn2 = self.F_tile.F_rn2,
|
||||
F_rk2 = self.F_tile.F_rk2,
|
||||
F_wm0 = self.F_tile.F_wm0,
|
||||
F_wn0 = self.F_tile.F_wn0,
|
||||
F_wk0 = self.F_tile.F_wk0,
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_spad = BOOL_MAP[self.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_dvpad],
|
||||
F_bias = BIAS_MAP[self.F_bias],
|
||||
F_dbias = BOOL_MAP[self.F_dbias],
|
||||
F_dropout = DROPOUT_MAP[self.F_dropout],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_deterministic = BOOL_MAP[self.F_deterministic],
|
||||
F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
|
||||
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
|
||||
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -382,7 +410,7 @@ class FmhaBwdDQDKDVKernel:
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
|
||||
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
if self.F_dbias == 't' : n += '_dbias'
|
||||
@@ -390,7 +418,8 @@ class FmhaBwdDQDKDVKernel:
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
|
||||
if self.F_deterministic == 't' : n += '_deterministic'
|
||||
return n
|
||||
|
||||
@property
|
||||
@@ -413,19 +442,23 @@ class FmhaBwdDQDKDVKernel:
|
||||
spad=self.F_spad,
|
||||
skpad=self.F_skpad,
|
||||
dpad=self.F_dpad,
|
||||
dvpad=self.F_dvpad)
|
||||
dvpad=self.F_dvpad,
|
||||
deterministic=self.F_deterministic
|
||||
)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size & pipeline.
|
||||
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1),
|
||||
"qs_ks_vr_dos"],
|
||||
'64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
|
||||
"qs_ks_vr_dos"],
|
||||
'128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
|
||||
"ks_vr"]
|
||||
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"]
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -440,7 +473,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
tile = d[hdim_str][0]
|
||||
ppl = d[hdim_str][1]
|
||||
hdim = int(hdim_str)
|
||||
@@ -448,16 +481,29 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
continue
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
if ("wg32" in dropout):
|
||||
continue
|
||||
if (dpad == "t" or dvpad == "t"):
|
||||
ppl = d[hdim_str][2]
|
||||
k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
|
||||
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
|
||||
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
|
||||
F_pipeline=ppl, mask_impl=mask_impl)
|
||||
F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic)
|
||||
if kernel_filter != None:
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
if receipt == 3:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_dq_dk_dv_traits(k.api_trait())
|
||||
@@ -468,53 +514,54 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad},
|
||||
{F_dvpad},
|
||||
{F_occupancy}>;
|
||||
using fmha_bwd_dot_do_o_trait_{F_idx} =
|
||||
ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
|
||||
/* BlockSize = */ 256,
|
||||
/* BlockSize = */ 64,
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
fmha_bwd_dot_do_o_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO<
|
||||
fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
|
||||
using fmha_bwd_dot_do_o_{F_idx} =
|
||||
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdOGradDotOTilePartitioner</* BlockSize = */ 256>,
|
||||
fmha_bwd_dot_do_o_{F_idx}>;
|
||||
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;
|
||||
|
||||
using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
|
||||
using dot_do_o_trait_{F_idx} =
|
||||
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
template <>
|
||||
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
@@ -584,12 +631,150 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
|
||||
|
||||
return gen
|
||||
|
||||
FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_bwd_convert_dq_trait_{F_idx} =
|
||||
ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;
|
||||
|
||||
using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
|
||||
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
|
||||
/* BlockSize = */ 256,
|
||||
{F_bm0},
|
||||
{F_bn0},
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
{F_deterministic},
|
||||
fmha_bwd_convert_dq_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_convert_dq_{F_idx} =
|
||||
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_convert_dq_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;
|
||||
|
||||
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
|
||||
{F_dtype},
|
||||
{F_mode},
|
||||
{F_spad},
|
||||
{F_dpad},
|
||||
{F_deterministic}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
|
||||
return k_::GetName();
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FmhaBwdConvertQGradKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_spad : str # true/false
|
||||
F_dpad : str #
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_occupancy : int #
|
||||
F_deterministic : str #
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_BWD_KERNEL_HEADER + \
|
||||
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_bm0,
|
||||
F_bn0 = self.F_bn0,
|
||||
F_spad = BOOL_MAP[self.F_spad],
|
||||
F_dpad = BOOL_MAP[self.F_dpad],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_occupancy = self.F_occupancy,
|
||||
F_deterministic = BOOL_MAP[self.F_deterministic])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_deterministic == 't' : n += f'_deterministic'
|
||||
return n
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
# support this in future
|
||||
def get_occupancy(dtype, hdim):
|
||||
return 2
|
||||
|
||||
gen = list()
|
||||
|
||||
for dtype in DTYPE_MAP.keys():
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
hdim = int(hdim_str)
|
||||
tile = d[hdim_str][0]
|
||||
if (mode == "group" and spad == "f"):
|
||||
continue
|
||||
k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
|
||||
F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
|
||||
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
@@ -597,6 +782,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
|
||||
kernels = get_bwd_dot_do_o_blobs()
|
||||
for kernel in kernels:
|
||||
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
|
||||
kernels = get_bwd_convert_dq_blobs()
|
||||
for kernel in kernels:
|
||||
write_single_bwd_convert_dq_kernel(kernel, output_dir)
|
||||
api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
|
||||
@@ -605,6 +793,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
|
||||
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
kernels = get_bwd_dot_do_o_blobs()
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
kernels = get_bwd_convert_dq_blobs()
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
|
||||
|
||||
@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("deterministic",
|
||||
"0",
|
||||
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
|
||||
"will not be used");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -128,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
|
||||
{
|
||||
std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
|
||||
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
|
||||
@@ -177,9 +176,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seed.reset();
|
||||
}
|
||||
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
bool deterministic = arg_parser.get_bool("deterministic");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
@@ -265,6 +265,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
|
||||
const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
|
||||
const ck_tile::index_t nsplits =
|
||||
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
@@ -284,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
|
||||
ck_tile::HostTensor<DDataType> d_host(
|
||||
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host(
|
||||
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
@@ -302,6 +305,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
use_dbias
|
||||
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<AccDataType> dq_acc_host(
|
||||
i_perm
|
||||
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
|
||||
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
@@ -362,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -387,8 +395,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
|
||||
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
|
||||
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask
|
||||
<< std::flush;
|
||||
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
|
||||
<< ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
|
||||
|
||||
std::size_t workspace_size =
|
||||
dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024);
|
||||
|
||||
if(deterministic == 1)
|
||||
{
|
||||
std::cout << "\nDeterministic mode ON: " << workspace_size
|
||||
<< " MByte memory workspace allocated" << std::endl;
|
||||
}
|
||||
|
||||
auto fmha_traits = fmha_bwd_traits{hdim_q,
|
||||
hdim_v,
|
||||
@@ -397,7 +414,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
mask.type,
|
||||
bias.type,
|
||||
use_dbias,
|
||||
p_drop > 0.0f};
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic};
|
||||
auto fmha_args = [&]() {
|
||||
assert(nhead % nhead_k == 0);
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
@@ -422,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_lsed = max_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_dbias =
|
||||
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
|
||||
// setup batch_stride_* arguments
|
||||
@@ -433,10 +452,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
|
||||
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t split_stride_dq_acc =
|
||||
(shape_batch * nhead * shape_seqlen_q * hdim_q);
|
||||
|
||||
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
@@ -452,6 +473,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
dk_buf.GetDeviceBuffer(),
|
||||
dv_buf.GetDeviceBuffer(),
|
||||
dbias_buf.GetDeviceBuffer(),
|
||||
dq_acc_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
@@ -473,6 +495,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
stride_o,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_q, // stride_dq_acc
|
||||
stride_q, // stride_dq
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
@@ -484,6 +508,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_q, // nhead_stride_dq_acc
|
||||
nhead_stride_q, // nhead_stride_dq
|
||||
nhead_stride_k, // nhead_stride_dk
|
||||
nhead_stride_v, // nhead_stride_dv
|
||||
nhead_stride_dbias,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
@@ -493,15 +521,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_q, // batch_stride_dq_acc
|
||||
batch_stride_q, // batch_stride_dq
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_drop,
|
||||
p_undrop,
|
||||
s_randval,
|
||||
{drop_seed, drop_offset}};
|
||||
}();
|
||||
|
||||
@@ -719,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
|
||||
else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
|
||||
|
||||
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); });
|
||||
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
|
||||
// clang-format on
|
||||
|
||||
q_host_refs.push_back(q_host_ref);
|
||||
@@ -738,6 +768,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dq_buf.SetZero();
|
||||
dbias_buf.SetZero();
|
||||
dq_acc_buf.SetZero();
|
||||
|
||||
ck_tile::stream_config stream_config_v{
|
||||
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
|
||||
|
||||
@@ -77,6 +77,7 @@ struct fmha_bwd_args
|
||||
void* dk_ptr;
|
||||
void* dv_ptr;
|
||||
void* dbias_ptr;
|
||||
void* dq_acc_ptr;
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
@@ -97,6 +98,8 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_do;
|
||||
ck_tile::index_t stride_dq_acc;
|
||||
ck_tile::index_t stride_dq;
|
||||
ck_tile::index_t stride_dk;
|
||||
ck_tile::index_t stride_dv;
|
||||
ck_tile::index_t stride_dbias;
|
||||
@@ -108,6 +111,10 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dq;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
ck_tile::index_t nhead_stride_dbias;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
@@ -117,15 +124,17 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
ck_tile::index_t batch_stride_dq_acc;
|
||||
ck_tile::index_t batch_stride_dq;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
ck_tile::index_t batch_stride_dbias;
|
||||
ck_tile::index_t split_stride_dq_acc;
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
float p_drop;
|
||||
float p_undrop;
|
||||
bool s_randval;
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset;
|
||||
};
|
||||
|
||||
@@ -145,10 +154,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dq_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
@@ -163,6 +172,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
@@ -173,13 +183,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_lsed,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
@@ -192,10 +204,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dq_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
@@ -209,6 +221,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
@@ -219,6 +232,9 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
@@ -227,14 +243,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
args.batch_stride_dq_acc,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
@@ -260,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_lsed,
|
||||
args.batch_stride_lsed);
|
||||
args.nhead_stride_lsed);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -286,19 +302,59 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaBwdConvertQGradKernel>
|
||||
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.batch_stride_dq,
|
||||
args.batch_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
typename FmhaDropout_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kHasDropout_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
bool kPadDv_,
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_dq_dk_dv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -306,13 +362,14 @@ struct fmha_bwd_dq_dk_dv_traits_
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
@@ -343,6 +400,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_dot_do_o_get_name_();
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
bool kPadS_,
|
||||
bool kPadD_,
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_convert_dq_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_convert_dq_get_name_();
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fmha_bwd_traits
|
||||
{
|
||||
@@ -354,6 +436,8 @@ struct fmha_bwd_traits
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_dbias;
|
||||
bool has_dropout;
|
||||
bool is_store_randval;
|
||||
bool is_deterministic;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -479,16 +479,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: std::array<ck_tile::index_t, 2>{1, 1});
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
||||
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
1 < num_splits
|
||||
? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
ck_tile::HostTensor<OaccDataType> o_acc_host(
|
||||
1 < num_splits
|
||||
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
// self define lse data layout as [batch, nhead, max_seqlen_q]
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
// group mode of lse data layout is [nhead, total_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
|
||||
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
@@ -669,8 +671,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
// setup batch_stride_* arguments
|
||||
@@ -679,12 +681,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
|
||||
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
|
||||
|
||||
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
|
||||
@@ -996,8 +998,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
lse_host_result.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
|
||||
lse_host_result.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = lse_host(b, idx[0], idx[1] + query_offset);
|
||||
});
|
||||
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
|
||||
@@ -185,7 +185,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_lse,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
@@ -284,7 +283,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc,
|
||||
@@ -376,9 +374,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_o_acc,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.batch_stride_lse,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc);
|
||||
}
|
||||
|
||||
@@ -11,18 +11,19 @@ COMMON_ARGS='-v=1'
|
||||
set -x
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 32 64 128 ; do
|
||||
for hdim in 32 64 128 256 ; do
|
||||
for mode in 0 1 ; do
|
||||
for bias in "n" "e" "a"; do
|
||||
for dbias in 0 1 ; do
|
||||
for p_drop in 0.0 0.2; do
|
||||
for bias in "n" "a" ; do
|
||||
for dbias in 0 ; do
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
for deterministic in 0 ; do
|
||||
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done
|
||||
done
|
||||
@@ -31,4 +32,5 @@ done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
set +x
|
||||
|
||||
Reference in New Issue
Block a user