diff --git a/CMakeLists.txt b/CMakeLists.txt index 2039948a12..8a08ddd19c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,17 +62,14 @@ if (DTYPES) endif() message("DTYPES macro set to ${DTYPES}") else() - add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) + add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8) set(CK_ENABLE_INT8 "ON") set(CK_ENABLE_FP16 "ON") set(CK_ENABLE_FP32 "ON") set(CK_ENABLE_FP64 "ON") set(CK_ENABLE_BF16 "ON") - if (GPU_TARGETS MATCHES "gfx94") - add_definitions(-DCK_ENABLE_FP8 -DCK_ENABLE_BF8) - set(CK_ENABLE_FP8 "ON") - set(CK_ENABLE_BF8 "ON") - endif() + set(CK_ENABLE_FP8 "ON") + set(CK_ENABLE_BF8 "ON") endif() #for f8/bf8_t type diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index efd47f3851..f82207b920 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -26,7 +26,7 @@ execute_process( execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt + --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3 ) # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory @@ -43,7 +43,7 @@ add_custom_command( add_custom_command( OUTPUT ${FMHA_BWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3 ) set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") @@ -75,11 +75,10 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) # ... 
because they are auto-generated if(FMHA_FWD_FAST_EXP2) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) - list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) - list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() +list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero) # conditionally enable call to the fwd_splitkv API in fmha_fwd example if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS) diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 4adf079d71..493b2177e3 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -66,6 +66,22 @@ BIAS_CHECK_MAP = { "alibi" : "bias_enum::alibi" } +DROPOUT_MAP = { + "no" : "ck_tile::BlockDropoutBwd", + "dropout_wg32" : "ck_tile::BlockDropoutBwd", + "dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd", + "dropout_wg16" : "ck_tile::BlockDropoutBwd", + "dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd" +} + +DROPOUT_CHECK_MAP = { + "no" : "t.has_dropout == false", + "dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true", + "dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true", +} + ROPE_MAP = { "no" : "ck_tile::RotaryEmbeddingEnum::NONE", "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py 
index 0df115dc3d..096394c0c9 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -14,15 +14,13 @@ from codegen.cpp_symbol_map import * BWD_DQDKDV_PIPELINE_MAP = { - "ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR", - "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS", - "ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR", + "kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP", + "kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR", } BWD_DQDKDV_PIPELINE_ENUM_MAP = { - "ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR", - "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS", - "ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR", + "kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP", + "kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR", } FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT @@ -34,39 +32,42 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; +using fmha_block_tile_{F_idx} = ck_tile:: + sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; -using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; +using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>; +using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>; // TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape // G0&G2 -> GSdP // G1&G3 -> GdKV // G4 -> GdQ using fmha_bwd_shape_{F_idx} = 
ck_tile::TileFmhaBwdShape; + fmha_block_warps0_{F_idx}, + fmha_warp_tile0_{F_idx}, + fmha_block_warps1_{F_idx}, + fmha_warp_tile1_{F_idx}, + fmha_block_warps0_{F_idx}, + fmha_warp_tile0_{F_idx}, + fmha_block_warps1_{F_idx}, + fmha_warp_tile1_{F_idx}, + fmha_block_warps2_{F_idx}, + fmha_warp_tile0_{F_idx}>; using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_bias}, - {F_dbias}, - false, - {F_dropout}, - false, - {F_occupancy}>; -using fmha_mask_{F_idx} = {F_mask}; + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + {F_dbias}, + false, + false, + false, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; +using fmha_dropout_{F_idx} = {F_dropout}; using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< typename FmhaBwdTypeConfig::QDataType, @@ -86,55 +87,72 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< typename FmhaBwdTypeConfig::BiasGradDataType, fmha_bwd_shape_{F_idx}, {F_mode}, + {F_deterministic}, fmha_mask_{F_idx}, + fmha_dropout_{F_idx}, fmha_bwd_trait_{F_idx}>; -using fmha_bwd_pipeline_{F_idx} = {F_pipeline}< - fmha_bwd_pipeline_problem_{F_idx}>; +using fmha_bwd_pipeline_{F_idx} = {F_pipeline}; -using fmha_bwd_dk_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::AccDataType, - typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, - false, false>>; +using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, + {F_skpad}, + {F_dpad}>>; -using fmha_bwd_dv_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::AccDataType, - typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, - false, false>>; +using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, + {F_skpad}, + {F_dvpad}>>; using fmha_bwd_dq_dk_dv_kernel_{F_idx} = - 
ck_tile::FmhaBwdDQDKDVKernel, - fmha_bwd_pipeline_{F_idx}, - fmha_bwd_dk_epilogue_{F_idx}, - fmha_bwd_dv_epilogue_{F_idx}>; + ck_tile::FmhaBwdDQDKDVKernel; -using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; +using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, + {F_dtype}, + {F_mode}, + {F_pipeline_enum}, + fmha_mask_{F_idx}, + fmha_dropout_{F_idx}, + {F_bias}, + {F_dbias}, + {F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_deterministic}>; #include -template<> +template <> float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} -template<> -void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) {{ - using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); + 
ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); }} -template<> +template <> std::string fmha_bwd_dq_dk_dv_get_name_() {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; @@ -146,14 +164,15 @@ FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" FMHA_BWD_API=""" #include -template +template float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} ); }} @@ -173,38 +192,36 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < }} """ -FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; +FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{ using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; - r = fmha_bwd_(s, a); + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>; + r = fmha_bwd_(s, a); return r; }} """ @dataclass class FmhaBwdDQDKDVApiTrait: - pipeline : str + pipeline : str # sync with fmha_bwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along k seqlen - bhdq : int # q head_dim - bhdv : int # v head_dim - mask : str - bias : str - dbias : str - dropout : str - spad : str - skpad : str - dpad : str - dvpad : str - - @property - def name(self) -> str: - return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along k seqlen + bhdq : int # q head_dim + bhdv : int # v head_dim + mask : str + bias : str + dbias : str + dropout : str + spad : str + skpad : str + dpad : str + dvpad : str + deterministic : str def scheck(self, spad1 : str) -> str: if self.mode == 'group': @@ -212,9 +229,9 @@ class FmhaBwdDQDKDVApiTrait: elif self.spad == 't' and spad1 == 't': return f'a.seqlen_q % {self.bm0} != 0' elif self.spad == 'f' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize + return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0' else: # self.skpad == 'f' and skpad1 == 'f' - return f'a.seqlen_q % 256 == 0' # BlockSize + return f'a.seqlen_q % 64 == 0' 
@property def skcheck(self) -> str: @@ -256,16 +273,19 @@ class FmhaBwdApiPool: per_hdim_case=str() for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): traits=self.dq_dk_dv_pool[dtype][hdim] + hdim_int = int(hdim) inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' for spad1 in ["t", "f"]: - if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")): + if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")): continue - inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], + inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], - F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]) + F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_deterministic=BOOL_MAP[trait.deterministic]) if_j = 'if' if j == 0 else 'else if' per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, 
F_inner_dispatch=inners) @@ -295,81 +315,89 @@ class FmhaBwdDQDKDVTileSize: F_bhdv : int # v head_dim F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 - F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2 + F_rk0 : int # number of warps along headdim_qk/v (not used) in gemm0/gemm2 F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 - F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3 - F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3 - F_rm2 : int # number of warps along k seqlen (block warps) in gemm4 - F_rn2 : int # number of warps along q seqlen (block warps) in gemm4 - F_rk2 : int # number of warps along gemm-k (not used) in gemm4 - F_wm : int # warp size along m (warp size) - F_wn : int # warp size along n - F_wk : int # warp size along k + F_rn1 : int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3 + F_rk1 : int # number of warps along q seqlen (not used) in gemm1/gemm3 + F_rm2 : int # number of warps along q seqlen (block warps) in gemm4 + F_rn2 : int # number of warps along headdim_qk (block warps) in gemm4 + F_rk2 : int # number of warps along k seqlen (not used) in gemm4 + F_wm0 : int # warp size along m in gemm0/gemm2/gemm4 + F_wn0 : int # warp size along n in gemm0/gemm2/gemm4 + F_wk0 : int # warp size along k in gemm0/gemm2/gemm4 + F_wm1 : int # warp size along m in gemm1/gemm3 + F_wn1 : int # warp size along n in gemm1/gemm3 + F_wk1 : int # warp size along k in gemm1/gemm3 F_occupancy : int # occupancy @property def name(self) -> str: return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ - f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}" + 
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}" @dataclass class FmhaBwdDQDKDVKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_tile : FmhaBwdDQDKDVTileSize - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_bias : str # - F_dbias : str # - F_dropout : str # - F_mask : str # value from MASK_MAP - F_mode : str # value from MODE_MAP - F_pipeline : str - mask_impl : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaBwdDQDKDVTileSize + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # + F_dbias : str # + F_dropout : str # + F_mask : str # value from MASK_MAP + F_mode : str # value from MODE_MAP + F_deterministic : str # + F_pipeline : str # + mask_impl : str # @property def template(self) -> str: return FMHA_BWD_KERNEL_HEADER + \ FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bk1 = self.F_tile.F_bk1, - F_bk2 = self.F_tile.F_bk2, - F_bk3 = self.F_tile.F_bk3, - F_bk4 = self.F_tile.F_bk4, - F_bhdq = self.F_tile.F_bhdq, - F_bhdv = self.F_tile.F_bhdv, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_rm2 = self.F_tile.F_rm2, - F_rn2 = self.F_tile.F_rn2, - F_rk2 = self.F_tile.F_rk2, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, - F_spad = BOOL_MAP[self.F_spad], - F_skpad = BOOL_MAP[self.F_skpad], - F_dpad = BOOL_MAP[self.F_dpad], - F_dvpad = BOOL_MAP[self.F_dvpad], - F_bias = BIAS_MAP[self.F_bias], - F_dbias = BOOL_MAP[self.F_dbias], - 
F_dropout = BOOL_MAP[self.F_dropout], - F_occupancy = self.F_tile.F_occupancy, - F_mask = get_mask_map(self.mask_impl)[self.F_mask], - F_mode = MODE_MAP[self.F_mode], + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bk1 = self.F_tile.F_bk1, + F_bk2 = self.F_tile.F_bk2, + F_bk3 = self.F_tile.F_bk3, + F_bk4 = self.F_tile.F_bk4, + F_bhdq = self.F_tile.F_bhdq, + F_bhdv = self.F_tile.F_bhdv, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_rm2 = self.F_tile.F_rm2, + F_rn2 = self.F_tile.F_rn2, + F_rk2 = self.F_tile.F_rk2, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_spad = BOOL_MAP[self.F_spad], + F_skpad = BOOL_MAP[self.F_skpad], + F_dpad = BOOL_MAP[self.F_dpad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_bias = BIAS_MAP[self.F_bias], + F_dbias = BOOL_MAP[self.F_dbias], + F_dropout = DROPOUT_MAP[self.F_dropout], + F_occupancy = self.F_tile.F_occupancy, + F_mask = get_mask_map(self.mask_impl)[self.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_deterministic = BOOL_MAP[self.F_deterministic], F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], - F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) + F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) @property def name(self) -> str: @@ -382,7 +410,7 @@ class FmhaBwdDQDKDVKernel: if n != '' : n = 'p' + n return n pn = pad_name() - n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}' if pn != '' : n += f'_{pn}' if self.F_bias != 'no' : n += f'_{self.F_bias}' if self.F_dbias == 't' : n += '_dbias' @@ 
-390,7 +418,8 @@ class FmhaBwdDQDKDVKernel: if self.F_mask == 's_mask': n += f'_mask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - if self.F_dropout == 't' : n += '_dropout' + if self.F_dropout != 'no' : n += f'_{self.F_dropout}' + if self.F_deterministic == 't' : n += '_deterministic' return n @property @@ -413,19 +442,23 @@ class FmhaBwdDQDKDVKernel: spad=self.F_spad, skpad=self.F_skpad, dpad=self.F_dpad, - dvpad=self.F_dvpad) + dvpad=self.F_dvpad, + deterministic=self.F_deterministic + ) # TODO: design a more practical way to do it # this is current supported tile size & pipeline. def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1), - "qs_ks_vr_dos"], - '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), - "qs_ks_vr_dos"], - '128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), - "ks_vr"] + '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr_iglp", "kr_ktr_vr"], + '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr_iglp", "kr_ktr_vr"], + '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr_iglp", "kr_ktr_vr"], + '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr_iglp", "kr_ktr_vr"] } else: return None @@ -440,7 +473,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) if d == None: continue - for hdim_str, mode, 
mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): + for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): tile = d[hdim_str][0] ppl = d[hdim_str][1] hdim = int(hdim_str) @@ -448,16 +481,29 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> continue if ((bias == "no" or bias == "alibi") and dbias == "t"): continue + if ("wg32" in dropout): + continue + if (dpad == "t" or dvpad == "t"): + ppl = d[hdim_str][2] k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, - F_pipeline=ppl, mask_impl=mask_impl) + F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic) if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue if receipt == 2: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'alibi'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + if not cond: + continue + if receipt == 3: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dpad == dvpad + cond &= deterministic == "f" if not cond: continue api_pool.register_dq_dk_dv_traits(k.api_trait()) @@ -468,53 +514,54 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; -using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, - {F_dvpad}, - {F_occupancy}>; +using fmha_bwd_dot_do_o_trait_{F_idx} = + 
ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>; using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< typename FmhaBwdTypeConfig::ODataType, typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::DDataType, - /* BlockSize = */ 256, + /* BlockSize = */ 64, {F_hdim}, {F_mode}, fmha_bwd_dot_do_o_trait_{F_idx}>; -using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO< - fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>; +using fmha_bwd_dot_do_o_{F_idx} = + typename ck_tile::BlockFmhaBwdOGradDotO; using fmha_bwd_dot_do_o_kernel_{F_idx} = - ck_tile::FmhaBwdOGradDotOKernel, - fmha_bwd_dot_do_o_{F_idx}>; + ck_tile::FmhaBwdOGradDotOKernel; -using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; +using dot_do_o_trait_{F_idx} = + fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; #include -template<> +template <> float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} -template<> +template <> void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ - using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; - auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + auto [kargs, grids] = 
fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); }} -template<> +template <> std::string fmha_bwd_dot_do_o_get_name_() {{ using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; @@ -584,12 +631,150 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: return gen +FMHA_BWD_CONVERT_DQ_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_bwd_convert_dq_trait_{F_idx} = + ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>; + +using fmha_bwd_convert_dq_pipeline_problem_{F_idx} = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + {F_bm0}, + {F_bn0}, + {F_hdim}, + {F_mode}, + {F_deterministic}, + fmha_bwd_convert_dq_trait_{F_idx}>; + +using fmha_bwd_convert_dq_{F_idx} = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_{F_idx} = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim}, + {F_dtype}, + {F_mode}, + {F_spad}, + {F_dpad}, + {F_deterministic}>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args 
a) +{{ + using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{{ + using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +@dataclass +class FmhaBwdConvertQGradKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_spad : str # true/false + F_dpad : str # + F_mode : str # value from MODE_MAP + F_occupancy : int # + F_deterministic : str # + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_bm0, + F_bn0 = self.F_bn0, + F_spad = BOOL_MAP[self.F_spad], + F_dpad = BOOL_MAP[self.F_dpad], + F_mode = MODE_MAP[self.F_mode], + F_occupancy = self.F_occupancy, + F_deterministic = BOOL_MAP[self.F_deterministic]) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dpad == 't' : n += 'd' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}" + if pn != '' : n += f'_{pn}' + if self.F_deterministic == 't' : n += f'_deterministic' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + +def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def 
get_occupancy(dtype, hdim): + return 2 + + gen = list() + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + hdim = int(hdim_str) + tile = d[hdim_str][0] + if (mode == "group" and spad == "f"): + continue + k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0, + F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic) + gen.append(k) + + return gen + def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: (autogen_dir / kernel.filename).write_text(kernel.template) def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: (autogen_dir / kernel.filename).write_text(kernel.template) +def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) @@ -597,6 +782,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_ kernels = get_bwd_dot_do_o_blobs() for kernel in kernels: write_single_bwd_dot_do_o_kernel(kernel, output_dir) + kernels = get_bwd_convert_dq_blobs() + for kernel in kernels: + write_single_bwd_convert_dq_kernel(kernel, output_dir) api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) @@ -605,6 +793,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: with file_path.open('a') as f: kernels = 
get_bwd_dot_do_o_blobs() + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + kernels = get_bwd_convert_dq_blobs() for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index b1249b5eda..efae4e284a 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[]) .insert("drop_offset", "0", "offset for random number generator") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "20", "number of iterations to benchmark the kernel"); + .insert("repeat", "20", "number of iterations to benchmark the kernel") + .insert("deterministic", + "0", + "if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion " + "will not be used"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -128,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); if(hdim_v < 0) hdim_v = hdim_q; - if(hdim_q % 2 != 0 || hdim_v % 2 != 0) - { - std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl; - return false; - } bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim @@ -177,9 +176,10 @@ bool run(const ck_tile::ArgParser& arg_parser) seed.reset(); } - int stream_warmup = arg_parser.get_int("warmup"); - int stream_repeat = arg_parser.get_int("repeat"); - bool kname = arg_parser.get_bool("kname"); + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool 
kname = arg_parser.get_bool("kname"); + bool deterministic = arg_parser.get_bool("deterministic"); ck_tile::stream_config stream_config{nullptr, true, @@ -265,6 +265,9 @@ bool run(const ck_tile::ArgParser& arg_parser) (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64; + const ck_tile::index_t nsplits = + deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1; ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); @@ -284,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor o_host( get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); ck_tile::HostTensor lse_host( - std::array{batch, nhead, max_seqlen_q}); + std::array{shape_batch, nhead, shape_seqlen_q}); ck_tile::HostTensor d_host( - std::array{batch, nhead, max_seqlen_q}); + std::array{shape_batch, nhead, shape_seqlen_q}); ck_tile::HostTensor randval_host( p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1}); @@ -302,6 +305,10 @@ bool run(const ck_tile::ArgParser& arg_parser) use_dbias ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor dq_acc_host( + i_perm + ? 
std::array{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q} + : std::array{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q}); if(init_method == 0) { @@ -362,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes()); q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); @@ -387,8 +395,17 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias - << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask - << std::flush; + << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval + << ", deterministic:" << deterministic << ", mask:" << mask << std::flush; + + std::size_t workspace_size = + dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024); + + if(deterministic == 1) + { + std::cout << "\nDeterministic mode ON: " << workspace_size + << " MByte memory workspace allocated" << std::endl; + } auto fmha_traits = fmha_bwd_traits{hdim_q, hdim_v, @@ -397,7 +414,9 @@ bool run(const ck_tile::ArgParser& arg_parser) mask.type, bias.type, use_dbias, - p_drop > 0.0f}; + p_drop > 0.0f, + s_randval, + deterministic}; auto fmha_args = [&]() { assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, @@ -422,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); - const ck_tile::index_t nhead_stride_lsed = max_seqlen_q; + const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q; const ck_tile::index_t nhead_stride_dbias = (i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k); // setup batch_stride_* arguments @@ -433,10 +452,12 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v); - const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t split_stride_dq_acc = + (shape_batch * nhead * shape_seqlen_q * hdim_q); return fmha_bwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), @@ -452,6 +473,7 @@ bool run(const ck_tile::ArgParser& arg_parser) dk_buf.GetDeviceBuffer(), dv_buf.GetDeviceBuffer(), dbias_buf.GetDeviceBuffer(), + dq_acc_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(), nullptr, @@ -473,6 +495,8 @@ bool run(const ck_tile::ArgParser& arg_parser) stride_o, stride_randval, stride_do, + stride_q, // stride_dq_acc + stride_q, // stride_dq stride_dk, stride_dv, stride_dbias, @@ -484,6 +508,10 @@ bool run(const ck_tile::ArgParser& arg_parser) nhead_stride_randval, nhead_stride_do, nhead_stride_lsed, + nhead_stride_q, // nhead_stride_dq_acc + nhead_stride_q, // nhead_stride_dq + nhead_stride_k, // nhead_stride_dk + 
nhead_stride_v, // nhead_stride_dv nhead_stride_dbias, batch_stride_q, batch_stride_k, @@ -493,15 +521,17 @@ bool run(const ck_tile::ArgParser& arg_parser) batch_stride_randval, batch_stride_do, batch_stride_lsed, + batch_stride_q, // batch_stride_dq_acc + batch_stride_q, // batch_stride_dq batch_stride_dk, batch_stride_dv, batch_stride_dbias, + split_stride_dq_acc, mask.left, mask.right, static_cast(mask.type), p_drop, p_undrop, - s_randval, {drop_seed, drop_offset}}; }(); @@ -719,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); }); else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); }); - lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); }); + lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); }); // clang-format on q_host_refs.push_back(q_host_ref); @@ -738,6 +768,7 @@ bool run(const ck_tile::ArgParser& arg_parser) lse_buf.ToDevice(lse_host.data()); dq_buf.SetZero(); dbias_buf.SetZero(); + dq_acc_buf.SetZero(); ck_tile::stream_config stream_config_v{ nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 0c6b468951..aea42515dc 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -77,6 +77,7 @@ struct fmha_bwd_args void* dk_ptr; void* dv_ptr; void* dbias_ptr; + void* dq_acc_ptr; const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; @@ -97,6 +98,8 @@ struct fmha_bwd_args ck_tile::index_t stride_o; ck_tile::index_t stride_randval; ck_tile::index_t stride_do; + ck_tile::index_t stride_dq_acc; + ck_tile::index_t stride_dq; ck_tile::index_t stride_dk; ck_tile::index_t stride_dv; ck_tile::index_t 
stride_dbias; @@ -108,6 +111,10 @@ struct fmha_bwd_args ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_lsed; + ck_tile::index_t nhead_stride_dq_acc; + ck_tile::index_t nhead_stride_dq; + ck_tile::index_t nhead_stride_dk; + ck_tile::index_t nhead_stride_dv; ck_tile::index_t nhead_stride_dbias; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -117,15 +124,17 @@ struct fmha_bwd_args ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_lsed; + ck_tile::index_t batch_stride_dq_acc; + ck_tile::index_t batch_stride_dq; ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dv; ck_tile::index_t batch_stride_dbias; + ck_tile::index_t split_stride_dq_acc; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; float p_drop; float p_undrop; - bool s_randval; std::tuple drop_seed_offset; }; @@ -145,10 +154,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.do_ptr, args.d_ptr, args.rand_val_ptr, - args.dq_ptr, args.dk_ptr, args.dv_ptr, args.dbias_ptr, + args.dq_acc_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_k_ptr, @@ -163,6 +172,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.stride_bias, args.stride_randval, args.stride_do, + args.stride_dq_acc, args.stride_dk, args.stride_dv, args.stride_dbias, @@ -173,13 +183,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.nhead_stride_randval, args.nhead_stride_do, args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, args.nhead_stride_dbias, - args.batch_stride_lsed, + args.split_stride_dq_acc, args.window_size_left, args.window_size_right, args.mask_type, args.p_drop, - args.s_randval, args.drop_seed_offset); } else @@ -192,10 +204,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.do_ptr, 
args.d_ptr, args.rand_val_ptr, - args.dq_ptr, args.dk_ptr, args.dv_ptr, args.dbias_ptr, + args.dq_acc_ptr, args.seqlen_q, args.seqlen_k, args.hdim_q, @@ -209,6 +221,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.stride_bias, args.stride_randval, args.stride_do, + args.stride_dq_acc, args.stride_dk, args.stride_dv, args.stride_dbias, @@ -219,6 +232,9 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.nhead_stride_randval, args.nhead_stride_do, args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, args.nhead_stride_dbias, args.batch_stride_q, args.batch_stride_k, @@ -227,14 +243,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.batch_stride_randval, args.batch_stride_do, args.batch_stride_lsed, + args.batch_stride_dq_acc, args.batch_stride_dk, args.batch_stride_dv, args.batch_stride_dbias, + args.split_stride_dq_acc, args.window_size_left, args.window_size_right, args.mask_type, args.p_drop, - args.s_randval, args.drop_seed_offset); } }(); @@ -260,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) args.stride_o, args.nhead_stride_do, args.nhead_stride_o, - args.nhead_stride_lsed, - args.batch_stride_lsed); + args.nhead_stride_lsed); } else { // create batch mode kernel arguments @@ -286,19 +302,59 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) return ck_tile::make_tuple(kargs, grids); } +template +auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode) + { + return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, + args.dq_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.hdim_q, + args.stride_dq, + args.stride_dq_acc, + args.nhead_stride_dq, + args.nhead_stride_dq_acc, + args.split_stride_dq_acc); + } + else + { // create batch mode kernel arguments + 
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, + args.dq_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.stride_dq, + args.stride_dq_acc, + args.nhead_stride_dq, + args.nhead_stride_dq_acc, + args.batch_stride_dq, + args.batch_stride_dq_acc, + args.split_stride_dq_acc); + } + }(); + + dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + // this is used to pattern-match internl kernel implementation, not to instantiate kernel template + bool kPadDv_, + bool kIsDeterministic_> struct fmha_bwd_dq_dk_dv_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -306,13 +362,14 @@ struct fmha_bwd_dq_dk_dv_traits_ static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; using FmhaMask = ck_tile::remove_cvref_t; + using FmhaDropout = ck_tile::remove_cvref_t; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kHasBiasGrad = kHasBiasGrad_; - static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kPadS = kPadS_; static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsDeterministic = kIsDeterministic_; }; template @@ -343,6 +400,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); template std::string fmha_bwd_dot_do_o_get_name_(); +template +struct fmha_bwd_convert_dq_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kIsDeterministic = kIsDeterministic_; +}; + +template +float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + 
+template +std::string fmha_bwd_convert_dq_get_name_(); + // This is the public API, will be generated by script struct fmha_bwd_traits { @@ -354,6 +436,8 @@ struct fmha_bwd_traits bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_dbias; bool has_dropout; + bool is_store_randval; + bool is_deterministic; // TODO: padding check is inside this api }; float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 7a2559e633..fc3bf2b0f0 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -609,16 +609,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor lse_acc_host( 1 < num_splits || use_kvcache - ? std::array{num_splits, batch, nhead, max_seqlen_q} + ? std::array{num_splits, shape_batch, nhead, shape_seqlen_q} : std::array{1, 1, 1, 1}); ck_tile::HostTensor o_acc_host( 1 < num_splits || use_kvcache ? std::array{num_splits, batch, nhead, max_seqlen_q, hdim_v} : std::array{1, 1, 1, 1, 1}); - // self define lse data layout as [batch, nhead, max_seqlen_q] + // batch mode of lse data layout is [batch, nhead, seqlen_q] + // group mode of lse data layout is [nhead, total_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{batch, nhead, max_seqlen_q} + lse ? std::array{shape_batch, nhead, shape_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( @@ -874,8 +875,8 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t nhead_stride_bias = (i_perm ? 
0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = max_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q; + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q; const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments @@ -890,13 +891,13 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); - const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); // setup split_stride_* arguments (only used in split-kv kernel) - const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q); + const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q); const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); args.q_ptr = q_buf.GetDeviceBuffer(); @@ -1439,8 +1440,9 @@ bool run(const ck_tile::ArgParser& arg_parser) if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - lse_host_result.ForEach( - [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], 
idx[1]); }); + lse_host_result.ForEach([&](auto& self, auto idx) { + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); + }); cur_pass = ck_tile::check_err(lse_host_result, lse_host_ref, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index be8016bebf..871ec2b523 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -303,7 +303,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, - args.batch_stride_lse, args.window_size_left, args.window_size_right, args.mask_type, @@ -486,9 +485,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.nhead_stride_o_acc, args.nhead_stride_lse, args.nhead_stride_o, - args.batch_stride_lse_acc, args.batch_stride_o_acc, - args.batch_stride_lse, args.split_stride_lse_acc, args.split_stride_o_acc); } diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index d6830aa2ec..dbb592820e 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -11,18 +11,19 @@ COMMON_ARGS='-v=1' set -x for prec in "fp16" "bf16" ; do for perm in 0 1 ; do -for hdim in 32 64 128 ; do +for hdim in 32 64 128 256 ; do for mode in 0 1 ; do -for bias in "n" "e" "a"; do -for dbias in 0 1 ; do -for p_drop in 0.0 0.2; do +for bias in "n" "a" ; do +for dbias in 0 ; do +for p_drop in 0.0 0.2 ; do +for deterministic in 0 ; do -$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop 
-iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS done done @@ -31,4 +32,5 @@ done done done done +done set +x diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp 
b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 71602e5d13..5c7e489804 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1> }; // 2D XOR, NOTE: "xor" is a keyword -template +template struct xor_t : public base_transform<2, 2> { static constexpr auto type_enum = coord_transform_enum::xor_t; @@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2> using UpLengths = LowLengths; UpLengths up_lengths_; - RightShift right_shift_; - CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {} + CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {} - CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths, - const RightShift& right_shift) - : up_lengths_{low_lengths}, right_shift_{right_shift} - { - } + CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {} CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() { @@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2> idx_low(number<0>{}) = idx_up[number<0>{}]; - const auto idx_low_1_tmp = - (idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}]; - - const auto idx_low_1 = - (idx_low_1_tmp >= 0) ? 
idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp; - - idx_low(number<1>{}) = idx_low_1; + idx_low(number<1>{}) = + idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]); } template @@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2> CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { - return ck_tile::is_known_at_compile_time::value && - ck_tile::is_known_at_compile_time::value; + return ck_tile::is_known_at_compile_time::value; } // MUST be static function @@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2> array up_vector_lengths = low_vector_lengths; array up_vector_strides = low_vector_strides; - if constexpr(ck_tile::is_known_at_compile_time::value) - { - if(low_vector_lengths[1] != -1) - { - up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_)); - } - } - return make_tuple(up_vector_lengths, up_vector_strides); } @@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2> print(up_lengths_); printf(", "); - // - printf("right_shift_: "); - print(right_shift_); - printf("}"); } }; @@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus, return modulo{modulus, up_length}; } -template -CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths, - const RightShift& right_shift) +template +CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths) { - return xor_t{low_lengths, right_shift}; + return xor_t{low_lengths}; } template diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index c23c12f295..3ef066a3eb 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16))); using int32x32_t = int32_t __attribute__((ext_vector_type(32))); using int32x64_t = int32_t 
__attribute__((ext_vector_type(64))); +// u32 +// using uint32_t = ... +using uint32x2_t = uint32_t __attribute__((ext_vector_type(2))); +using uint32x4_t = uint32_t __attribute__((ext_vector_type(4))); +using uint32x8_t = uint32_t __attribute__((ext_vector_type(8))); +using uint32x16_t = uint32_t __attribute__((ext_vector_type(16))); +using uint32x32_t = uint32_t __attribute__((ext_vector_type(32))); +using uint32x64_t = uint32_t __attribute__((ext_vector_type(64))); + // i16 // using int16_t = ... using int16x2_t = int16_t __attribute__((ext_vector_type(2))); diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index 42a30232fb..24c932f0a6 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( return make_tuple( make_static_tile_distribution( tile_distribution_encoding, // only need to + // change the + // h_lengths type typename Encoding::Ps2RHssMajor, typename Encoding::Ps2RHssMinor, typename Encoding::Ys2RHsMajor, diff --git a/include/ck_tile/core/utility/philox_rand.hpp b/include/ck_tile/core/utility/philox_rand.hpp index c49f44ae48..87abf5cc18 100644 --- a/include/ck_tile/core/utility/philox_rand.hpp +++ b/include/ck_tile/core/utility/philox_rand.hpp @@ -53,6 +53,39 @@ class philox out_tmp[3] = tmp_ph.w; } + CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out, + const unsigned long long subsequence, + const index_t start_idx) const + { + uint4 tmp_ph; + tmp_ph = get_philox_4x32(subsequence); + + uint32x4_t tmp; + tmp[0] = tmp_ph.x; + tmp[1] = tmp_ph.y; + tmp[2] = tmp_ph.z; + tmp[3] = tmp_ph.w; + uint32_t* out_tmp = reinterpret_cast(&out[0]); + out_tmp[0] = tmp[start_idx]; + out_tmp[1] = tmp[start_idx + 2]; + } + + CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out, + const unsigned long long subsequence, + const index_t start_idx) const + { + uint4 
tmp_ph; + tmp_ph = get_philox_4x32(subsequence); + + uint32x4_t tmp; + tmp[0] = tmp_ph.x; + tmp[1] = tmp_ph.y; + tmp[2] = tmp_ph.z; + tmp[3] = tmp_ph.w; + uint32_t* out_tmp = reinterpret_cast(&out[0]); + out_tmp[0] = tmp[start_idx]; + } + private: struct ull2 { diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 9a6865f15f..9389a5397f 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -10,7 +10,6 @@ #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" #include "ck_tile/ops/fmha/block/page_block_navigator.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" @@ -19,14 +18,10 @@ #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" +#include 
"ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp index 7ebb306cce..e036402e16 100644 --- a/include/ck_tile/ops/fmha/block/block_dropout.hpp +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -286,11 +286,226 @@ struct BlockDropout }); } + ck_tile::philox ph; + const float rp_undrop; + const uint8_t p_undrop_in_uint8_t; + const bool is_store_randval; +}; + +template +struct BlockDropoutBwd; + +template +struct BlockDropoutBwd +{ + static constexpr bool IsDropout = false; + static constexpr bool IsStoreRandval = IsStoreRandval_; + + template + __host__ __device__ static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + (void)randval_dram_block_window_tmp; + (void)seqlen_qk_start; + + return make_null_tile_window(make_tuple(number<0>{}, number<0>{})); + } +}; + +template +struct BlockDropoutBwd +{ + static constexpr bool IsDropout = true; + // true: 32*32 warp gemm + // false: 16*16 warp gemm + static constexpr bool IsWG32 = IsWG32_; + static constexpr bool IsStoreRandval = IsStoreRandval_; + + CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, + index_t i_head, + index_t nheads, + unsigned long long seed, + unsigned long long offset, + float rp_undrop_, + uint8_t p_undrop_in_uint8_t_) + : ph(seed, + offset + (i_batch * nheads + i_head) * get_warp_size() + + (IsWG32 ? 
get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))), + rp_undrop(rp_undrop_), + p_undrop_in_uint8_t(p_undrop_in_uint8_t_) + { + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using BlockGemmShape = remove_cvref_t; + using WG = remove_cvref_t())>; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16); + constexpr index_t kMPerStep = [&]() { + if constexpr(MBwdWG16MultiIterCheck) + { + return MWarp * WG::kM * 2; + } + else + { + return MWarp * WG::kM; + } + }(); + constexpr index_t kNPerStep = NWarp * WG::kN; + + const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); + auto randval_dram_window = [&]() { + if constexpr(IsFwd) + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N + } + else + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N + } + }(); + + return randval_dram_window; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = WG::kN; + constexpr index_t kN1 = 8; + constexpr index_t kN0 = kNPerStep / kN1; + + constexpr auto randval_lds_block_desc_0 = 
make_naive_tensor_descriptor( + ck_tile::make_tuple(number{}, number{}, number{}), + ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto randval_lds_block_desc = transform_tensor_descriptor( + randval_lds_block_desc_0, + ck_tile::make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(ck_tile::make_tuple(number{}, number{}))), + ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}), + ck_tile::make_tuple(sequence<0>{}, sequence<1>{})); + + return randval_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16); + + constexpr index_t MIterPerWarp = [&]() { + if constexpr(MBwdWG16MultiIterCheck) + { + return 2; + } + else + { + return 1; + } + }(); + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. + // except headdim256. 
+ constexpr auto randval_block_inner_part_dstr_encoding = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + if constexpr(IsWG32) + return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; + else + return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{}; + } + else + { + if constexpr(IsWG32) + return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; + else + return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{}; + } + }(); + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + randval_block_inner_part_dstr_encoding); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + typename WG::CWarpDstrEncoding{}); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + template - CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, + CK_TILE_HOST_DEVICE void Run(void* randval_ptr, + const index_t start_m0_idx, + const index_t start_n0_idx, PComputeWindow& p_compute, RandValDramWindow& randval_dram_window) const { @@ -305,30 +520,177 @@ struct BlockDropout constexpr index_t kMPerStep = MWarp * WG::kM; constexpr index_t kNPerStep = 
NWarp * WG::kN; - // register distribute - auto randval = - make_static_distributed_tensor(MakeRandValTileDistribution()); - static_assert(randval.kThreadElementSpaceSize == 16); + // randval tile in LDS + auto randval_lds = make_tensor_view( + reinterpret_cast(randval_ptr), MakeRandValLdsBlockDescriptor()); - const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{}); - static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { - static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { - int block_row_start = (start_m0_idx / WG::kM) + i_m0; - int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id(); + auto randval_lds_window = make_tile_window( + randval_lds, MakeRandValLdsBlockDescriptor().get_lengths(), {0, 0}); + + // register distribute + auto randval_dist_generated = + make_static_distributed_tensor(MakeRandValTileDistribution()); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + + auto randval_lds_read_window = + make_tile_window(randval_lds_window.get_bottom_tensor_view(), + randval_lds_window.get_window_lengths(), + randval_lds_window.get_window_origin(), + MakeRandValLdsShuffleTileDistribution()); + + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); + int block_col_start = (start_n0_idx / WG::kN) + i_n0; uint2 rowcol = make_uint2(block_row_start, block_col_start); // generate random number uint8_t random_uint8_t[16]; ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + 
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + // save to LDS + store_tile(randval_lds_window, randval_dist_generated); + block_sync_lds(); + // read from LDS to register + auto randval = load_tile(randval_lds_read_window); + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto p_idx0 = tile_distributed_index{}; + constexpr auto p_idx1 = + tile_distributed_index{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] * rp_undrop + : PComputeDataType(0); + }); + }); + // save to Global + if constexpr(IsStoreRandval) + { + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {0, kNPerStep}); + } + }); + if constexpr(IsStoreRandval) + { + move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); + } + }); + if constexpr(IsStoreRandval) + { + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); + } + } + + template + CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, + const index_t start_n0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16); + constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16); + constexpr index_t kMPerStep = [&]() { 
+ if constexpr(MBwdWG16MultiIterCheck) + { + return MWarp * WG::kM * 2; + } + else + { + return MWarp * WG::kM; + } + }(); + constexpr index_t kNPerStep = NWarp * WG::kN; + + // register distribute + auto randval = make_static_distributed_tensor( + MakeRandValTileDistribution()); + if constexpr(IsWG32) + static_assert(randval.kThreadElementSpaceSize == 16); + else + static_assert(randval.kThreadElementSpaceSize == 4 || + randval.kThreadElementSpaceSize == 8); + + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + int block_row_start, block_col_start; + if constexpr(IsWG32) + { + block_row_start = (start_m0_idx / WG::kM) + i_m0; + block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id(); + } + else + { + block_row_start = start_m0_idx / 32 + i_m0; + block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2; + } + uint2 rowcol = make_uint2(block_row_start, block_col_start); + + // generate random number + uint8_t* random_uint8_t_; + if constexpr(MBwdWG16SingleIterCheck) + { + uint8_t random_uint8_t[4]; + // m0t0 ~m0t15/m0t32~m0t47: 0 + // m0t16~m0t31/m0t48~m0t63: 1 + // m1t0 ~m1t15/m1t32~m1t47: 2 + // m1t16~m1t31/m1t48~m1t63: 3 + const index_t start_idx = + ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1); + ph.get_random_4x8( + random_uint8_t, reinterpret_cast(rowcol), start_idx); + random_uint8_t_ = random_uint8_t; + } + else if constexpr(MBwdWG16MultiIterCheck) + { + uint8_t random_uint8_t[8]; + // t0 ~t15/t32~t47: 0 + // t16~t31/t48~t63: 1 + const index_t start_idx = (get_lane_id() >> 4) & 1; + ph.get_random_8x8( + random_uint8_t, reinterpret_cast(rowcol), start_idx); + random_uint8_t_ = random_uint8_t; + } + else + { + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, + reinterpret_cast(rowcol)); + random_uint8_t_ = random_uint8_t; + } + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); int i_random_idx = 0; 
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { - constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); - randval(r_idx) = random_uint8_t[i_random_idx++]; - constexpr auto p_idx0 = - tile_distributed_index{}; + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + randval(r_idx) = random_uint8_t_[i_random_idx++]; + constexpr auto p_idx0 = tile_distributed_index{}; constexpr auto p_idx1 = tile_distributed_index{}; constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t @@ -337,19 +699,19 @@ struct BlockDropout }); }); // save to Global - if(is_store_randval) + if constexpr(IsStoreRandval) { const auto randval_store = cast_tile(randval); store_tile(randval_dram_window, randval_store); move_tile_window(randval_dram_window, {kMPerStep, 0}); } }); - if(is_store_randval) + if constexpr(IsStoreRandval) { move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep}); } }); - if(is_store_randval) + if constexpr(IsStoreRandval) { move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock}); } @@ -358,7 +720,6 @@ struct BlockDropout ck_tile::philox ph; const float rp_undrop; const uint8_t p_undrop_in_uint8_t; - const bool is_store_randval; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index e713cefbda..167494b193 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -23,13 +23,9 @@ namespace ck_tile { -template +template struct FmhaBwdDQDKDVKernel { - using TilePartitioner = ck_tile::remove_cvref_t; using FmhaPipeline = ck_tile::remove_cvref_t; using KGradEpiloguePipeline = ck_tile::remove_cvref_t; using VGradEpiloguePipeline = ck_tile::remove_cvref_t; @@ -59,9 +55,12 @@ struct FmhaBwdDQDKDVKernel static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; 
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; - static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; using FmhaMask = ck_tile::remove_cvref_t; - static constexpr bool kHasMask = FmhaMask::IsMasking; + using FmhaDropout = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + static constexpr bool kHasDropout = FmhaDropout::IsDropout; + static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval; + static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic; // clang-format off template struct t2s; @@ -73,9 +72,12 @@ struct FmhaBwdDQDKDVKernel { // sync with generate.py // clang-format off - using bfs = typename FmhaPipeline::BlockFmhaShape; - using gbr = typename bfs::Gemm0BlockWarps; - using gwt = typename bfs::Gemm0WarpTile; + using bfs = typename FmhaPipeline::BlockFmhaShape; + using gbr0 = typename bfs::Gemm0BlockWarps; + using gbr1 = typename bfs::Gemm1BlockWarps; + using gbr4 = typename bfs::Gemm4BlockWarps; + using gwt0 = typename bfs::Gemm0WarpTile; + using gwt1 = typename bfs::Gemm1WarpTile; #define _SS_ std::string #define _TS_ std::to_string auto pn = [&] () { @@ -88,13 +90,17 @@ struct FmhaBwdDQDKDVKernel return _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + "_" + (kIsGroupMode ? 
"group" : "batch") + "_" + - "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" + - "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" + + _TS_(bfs::kK4) + "x" + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" + + "r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" + ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ); + (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + + (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? 
"_deterministic" : "" ); #undef _SS_ #undef _TS_ // clang-format on @@ -117,7 +123,7 @@ struct FmhaBwdDQDKDVKernel const void* lse_ptr; const void* do_ptr; const void* d_ptr; - void* dq_ptr; + void* dq_acc_ptr; void* dk_ptr; void* dv_ptr; @@ -131,14 +137,13 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t num_head_q; ck_tile::index_t nhead_ratio_qk; float raw_scale; -#if CK_TILE_FMHA_FWD_FAST_EXP2 float scale; -#endif ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; ck_tile::index_t stride_do; + ck_tile::index_t stride_dq_acc; ck_tile::index_t stride_dk; ck_tile::index_t stride_dv; @@ -147,8 +152,9 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_lsed; - - ck_tile::index_t batch_stride_lsed; + ck_tile::index_t nhead_stride_dq_acc; + ck_tile::index_t nhead_stride_dk; + ck_tile::index_t nhead_stride_dv; }; struct FmhaBwdCommonBiasKargs @@ -206,7 +212,6 @@ struct FmhaBwdDQDKDVKernel float rp_undrop = 1; float scale_rp_undrop = 1; uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); - bool is_store_randval = false; uint64_t drop_seed = 1; uint64_t drop_offset = 0; void* rand_val_ptr = nullptr; @@ -218,6 +223,10 @@ struct FmhaBwdDQDKDVKernel { ck_tile::index_t batch_stride_randval = 0; }; + struct FmhaBwdDeterministicKargs + { + ck_tile::index_t split_stride_dq_acc = 0; + }; struct FmhaBwdBatchModeKargs : FmhaBwdCommonKargs, @@ -228,12 +237,15 @@ struct FmhaBwdDQDKDVKernel FmhaBwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_lsed; + ck_tile::index_t batch_stride_dq_acc; ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dv; }; @@ -247,7 +259,8 @@ struct FmhaBwdDQDKDVKernel FmhaBwdEmptyKargs<0>>>, 
std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -266,10 +279,10 @@ struct FmhaBwdDQDKDVKernel const void* do_ptr, const void* d_ptr, void* rand_val_ptr, - void* dq_ptr, void* dk_ptr, void* dv_ptr, void* dbias_ptr, + void* dq_acc_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, @@ -283,6 +296,7 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, + ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, @@ -293,6 +307,9 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, + ck_tile::index_t nhead_stride_dq_acc, + ck_tile::index_t nhead_stride_dk, + ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, @@ -301,14 +318,15 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, + ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, + ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, - bool s_randval, const std::tuple& drop_seed_offset) { Kargs kargs{{q_ptr, @@ -317,7 +335,7 @@ struct FmhaBwdDQDKDVKernel lse_ptr, do_ptr, d_ptr, - dq_ptr, + dq_acc_ptr, dk_ptr, dv_ptr, seqlen_q, @@ -327,13 +345,12 @@ struct FmhaBwdDQDKDVKernel num_head_q, nhead_ratio_qk, scale, -#if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale * ck_tile::log2e_v<>), -#endif stride_q, stride_k, stride_v, stride_do, + stride_dq_acc, stride_dk, stride_dv, nhead_stride_q, @@ -341,15 
+358,20 @@ struct FmhaBwdDQDKDVKernel nhead_stride_v, nhead_stride_do, nhead_stride_lsed, - batch_stride_lsed}, // args for common karg - {}, // placeholder for bias - {}, // placeholder for dbias - {}, // placeholder for mask - {}, // placeholder for dropout + nhead_stride_dq_acc, + nhead_stride_dk, + nhead_stride_dv}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dbias + {}, // placeholder for mask + {}, // placeholder for dropout + {}, // placeholder for deterministic batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_do, + batch_stride_lsed, + batch_stride_dq_acc, batch_stride_dk, batch_stride_dv}; @@ -384,11 +406,18 @@ struct FmhaBwdDQDKDVKernel if constexpr(kHasDropout) { kargs.init_dropout(p_drop, drop_seed_offset, scale); - kargs.rand_val_ptr = rand_val_ptr; - kargs.stride_randval = stride_randval; - kargs.nhead_stride_randval = nhead_stride_randval; - kargs.batch_stride_randval = batch_stride_randval; - kargs.is_store_randval = s_randval; + if constexpr(kIsStoreRandval) + { + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + } + } + + if constexpr(kIsDeterministic) + { + kargs.split_stride_dq_acc = split_stride_dq_acc; } return kargs; @@ -404,10 +433,10 @@ struct FmhaBwdDQDKDVKernel const void* do_ptr, const void* d_ptr, void* rand_val_ptr, - void* dq_ptr, void* dk_ptr, void* dv_ptr, void* dbias_ptr, + void* dq_acc_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, const void* seqlen_k_ptr, @@ -422,6 +451,7 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, + ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, @@ -432,13 +462,15 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, 
ck_tile::index_t nhead_stride_lsed, + ck_tile::index_t nhead_stride_dq_acc, + ck_tile::index_t nhead_stride_dk, + ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, - ck_tile::index_t batch_stride_lsed, + ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, - bool s_randval, const std::tuple& drop_seed_offset) { Kargs kargs{{q_ptr, @@ -447,7 +479,7 @@ struct FmhaBwdDQDKDVKernel lse_ptr, do_ptr, d_ptr, - dq_ptr, + dq_acc_ptr, dk_ptr, dv_ptr, -1, // seqlen will be updated by another pointer @@ -457,13 +489,12 @@ struct FmhaBwdDQDKDVKernel num_head_q, nhead_ratio_qk, scale, -#if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale * ck_tile::log2e_v<>), -#endif stride_q, stride_k, stride_v, stride_do, + stride_dq_acc, stride_dk, stride_dv, nhead_stride_q, @@ -471,11 +502,14 @@ struct FmhaBwdDQDKDVKernel nhead_stride_v, nhead_stride_do, nhead_stride_lsed, - batch_stride_lsed}, // args for common karg - {}, // placeholder for bias - {}, // placeholder for dbias - {}, // placeholder for mask - {}, // placeholder for dropout + nhead_stride_dq_acc, + nhead_stride_dk, + nhead_stride_dv}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dbias + {}, // placeholder for mask + {}, // placeholder for dropout + {}, // placeholder for deterministic reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -506,10 +540,16 @@ struct FmhaBwdDQDKDVKernel if constexpr(kHasDropout) { kargs.init_dropout(p_drop, drop_seed_offset, scale); - kargs.rand_val_ptr = rand_val_ptr; - kargs.stride_randval = stride_randval; - kargs.nhead_stride_randval = nhead_stride_randval; - kargs.is_store_randval = s_randval; + if constexpr(kIsStoreRandval) + { + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + } + } + if 
constexpr(kIsDeterministic) + { + kargs.split_stride_dq_acc = split_stride_dq_acc; } return kargs; @@ -518,7 +558,17 @@ struct FmhaBwdDQDKDVKernel CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) { - return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_); + return dim3( + ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex() + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -536,7 +586,7 @@ struct FmhaBwdDQDKDVKernel __shared__ char smem_ptr[GetSmemSize()]; // divide problem - const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k); + const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex(); const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0); @@ -547,6 +597,7 @@ struct FmhaBwdDQDKDVKernel long_index_t batch_offset_randval = 0; long_index_t batch_offset_do = 0; long_index_t batch_offset_lsed = 0; + long_index_t batch_offset_dq_acc = 0; long_index_t batch_offset_dk = 0; long_index_t batch_offset_dv = 0; long_index_t batch_offset_dbias = 0; @@ -557,13 +608,14 @@ struct FmhaBwdDQDKDVKernel const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - batch_offset_v = key_start * kargs.stride_v; - batch_offset_do = query_start * kargs.stride_do; - batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; - batch_offset_dk = key_start * kargs.stride_dk; - batch_offset_dv = key_start * kargs.stride_dv; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = 
key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; + batch_offset_do = query_start * kargs.stride_do; + batch_offset_lsed = query_start; + batch_offset_dq_acc = query_start * kargs.stride_dq_acc; + batch_offset_dk = key_start * kargs.stride_dk; + batch_offset_dv = key_start * kargs.stride_dv; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias; @@ -576,7 +628,7 @@ struct FmhaBwdDQDKDVKernel { batch_offset_dbias = key_start; } - if constexpr(kHasDropout) + if constexpr(kIsStoreRandval) { batch_offset_randval = query_start * kargs.stride_randval; } @@ -603,13 +655,14 @@ struct FmhaBwdDQDKDVKernel } else { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - batch_offset_do = static_cast(i_batch) * kargs.batch_stride_do; - batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; - batch_offset_dk = static_cast(i_batch) * kargs.batch_stride_dk; - batch_offset_dv = static_cast(i_batch) * kargs.batch_stride_dv; + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_do = static_cast(i_batch) * kargs.batch_stride_do; + batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; + batch_offset_dq_acc = static_cast(i_batch) * kargs.batch_stride_dq_acc; + batch_offset_dk = static_cast(i_batch) * kargs.batch_stride_dk; + batch_offset_dv = static_cast(i_batch) * kargs.batch_stride_dv; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; @@ -618,7 +671,7 @@ struct FmhaBwdDQDKDVKernel { batch_offset_dbias = static_cast(i_batch) * kargs.batch_stride_dbias; } - if constexpr(kHasDropout) + if 
constexpr(kIsStoreRandval) { batch_offset_randval = static_cast(i_batch) * kargs.batch_stride_randval; @@ -646,14 +699,11 @@ struct FmhaBwdDQDKDVKernel const OGradDataType* do_ptr = reinterpret_cast(kargs.do_ptr) + static_cast(i_nhead) * kargs.nhead_stride_do + batch_offset_do; - QGradDataType* dq_ptr = reinterpret_cast(kargs.dq_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_q + - batch_offset_q; KGradDataType* dk_ptr = reinterpret_cast(kargs.dk_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_k + + static_cast(i_nhead) * kargs.nhead_stride_dk + batch_offset_dk; VGradDataType* dv_ptr = reinterpret_cast(kargs.dv_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_v + + static_cast(i_nhead) * kargs.nhead_stride_dv + batch_offset_dv; // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window @@ -663,45 +713,10 @@ struct FmhaBwdDQDKDVKernel make_tuple(kargs.stride_q, 1), number{}, number<1>{}); - const auto q_dram = [&]() { - if constexpr(FmhaPipeline::kQLoadOnce) - { - return pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - - const auto qt_dram_naive = - transform_tensor_view(q_dram_naive, - make_tuple(make_pass_through_transform(kargs.hdim_q), - make_pass_through_transform(kargs.seqlen_q)), - make_tuple(sequence<1>{}, sequence<0>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - const auto qt_dram = [&]() { - if constexpr(FmhaPipeline::kQTLoadOnce) - { - return pad_tensor_view( - qt_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view( - qt_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); + const auto q_dram = pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); const auto k_dram_naive = make_naive_tensor_view( k_ptr, @@ -709,45 +724,10 @@ struct FmhaBwdDQDKDVKernel make_tuple(kargs.stride_k, 1), number{}, number<1>{}); - 
const auto k_dram = [&]() { - if constexpr(FmhaPipeline::kKLoadOnce) - { - return pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - - const auto kt_dram_naive = - transform_tensor_view(k_dram_naive, - make_tuple(make_pass_through_transform(kargs.hdim_q), - make_pass_through_transform(kargs.seqlen_k)), - make_tuple(sequence<1>{}, sequence<0>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - const auto kt_dram = [&]() { - if constexpr(FmhaPipeline::kKTLoadOnce) - { - return pad_tensor_view( - kt_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view( - kt_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); + const auto k_dram = pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( @@ -756,20 +736,10 @@ struct FmhaBwdDQDKDVKernel make_tuple(kargs.stride_v, 1), number{}, number<1>{}); - if constexpr(FmhaPipeline::kVLoadOnce) - { - return pad_tensor_view( - v_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view( - v_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); }(); const auto lse_dram = [&]() { @@ -792,145 +762,89 @@ struct FmhaBwdDQDKDVKernel make_tuple(kargs.stride_do, 1), number{}, number<1>{}); - const auto do_dram = [&]() { - if constexpr(FmhaPipeline::kOGradLoadOnce) - { - return pad_tensor_view( - do_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view( - do_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - - const auto dot_dram_naive = - transform_tensor_view(do_dram_naive, - 
make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_q)), - make_tuple(sequence<1>{}, sequence<0>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - const auto dot_dram = [&]() { - if constexpr(FmhaPipeline::kOGradTLoadOnce) - { - return pad_tensor_view( - dot_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view( - dot_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - - auto dq_dram = [&]() { - const auto dq_dram_naive = make_naive_tensor_view( - dq_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_q, 1), - number{}, - number<1>{}); - - return pad_tensor_view( - dq_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); + const auto do_dram = pad_tensor_view( + do_dram_naive, + make_tuple(number{}, number{}), + sequence{}); auto q_dram_window = make_tile_window( q_dram, - [&]() { - if constexpr(FmhaPipeline::kQLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, number{}); - }(), + make_tuple(number{}, number{}), {0, 0}); - auto qt_dram_window = - make_tile_window(qt_dram, - [&]() { - if constexpr(FmhaPipeline::kQTLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, - number{}); - }(), - {0, 0}); - auto k_dram_window = make_tile_window( k_dram, - [&]() { - if constexpr(FmhaPipeline::kKLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, number{}); - }(), + make_tuple(number{}, number{}), {i_n0, 0}); - auto kt_dram_window = - make_tile_window(kt_dram, - [&]() { - if constexpr(FmhaPipeline::kKTLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, - number{}); - }(), - {0, i_n0}); - auto v_dram_window = make_tile_window( v_dram, - [&]() { - if constexpr(FmhaPipeline::kVLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, number{}); - 
}(), + make_tuple(number{}, number{}), {i_n0, 0}); auto do_dram_window = make_tile_window( do_dram, - [&]() { - if constexpr(FmhaPipeline::kOGradLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, number{}); - }(), + make_tuple(number{}, number{}), {0, 0}); - auto dot_dram_window = - make_tile_window(dot_dram, - [&]() { - if constexpr(FmhaPipeline::kOGradTLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, - number{}); - }(), - {0, 0}); + auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() { + if constexpr(kIsDeterministic) + { + AccDataType* dq_acc_ptr = + reinterpret_cast(kargs.dq_acc_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + + static_cast(i_tile_n_) * kargs.split_stride_dq_acc + + batch_offset_dq_acc; - auto dq_dram_window = make_tile_window( - dq_dram, - make_tuple(number{}, number{}), - {0, 0}); + auto dq_acc_dram = [&]() { + const auto dq_acc_dram_naive = + make_naive_tensor_view( + dq_acc_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_dq_acc, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dq_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + return make_tile_window( + dq_acc_dram, + make_tuple(number{}, number{}), + {0, 0}); + } + else + { + AccDataType* dq_acc_ptr = + reinterpret_cast(kargs.dq_acc_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + + batch_offset_dq_acc; + + auto dq_acc_dram = [&]() { + const auto dq_acc_dram_naive = + make_naive_tensor_view( + dq_acc_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_dq_acc, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dq_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + return make_tile_window( + dq_acc_dram, + make_tuple(number{}, number{}), + {0, 0}); + } + }(); auto lse_dram_window = make_tile_window(lse_dram, make_tuple(number{}), {0}); @@ -1008,9 
+922,7 @@ struct FmhaBwdDQDKDVKernel // TODO: how to use s_read? AccDataType slope = *(reinterpret_cast(kargs.alibi_slope_ptr) + i_batch_ * kargs.alibi_slope_stride + i_nhead_); -#if CK_TILE_FMHA_FWD_FAST_EXP2 slope *= ck_tile::log2e_v<>; -#endif if constexpr(kHasMask) { return make_alibi_from_lr_mask(slope, @@ -1033,35 +945,34 @@ struct FmhaBwdDQDKDVKernel }(); // dropout - float rp_undrop = 1; - float scale_rp_undrop = 1; - uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); - uint64_t drop_seed = 0; - uint64_t drop_offset = 0; - bool is_store_randval = false; - + float rp_undrop = 1; + float scale_rp_undrop = 1; if constexpr(kHasDropout) { - rp_undrop = kargs.rp_undrop; - scale_rp_undrop = kargs.scale_rp_undrop; - p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; - drop_seed = kargs.drop_seed; - drop_offset = kargs.drop_offset; - is_store_randval = kargs.is_store_randval; + rp_undrop = kargs.rp_undrop; + scale_rp_undrop = kargs.scale_rp_undrop; } - BlockDropout dropout(i_batch, - i_nhead, - kargs.num_head_q, - drop_seed, - drop_offset, - rp_undrop, - p_undrop_in_uint8_t, - is_store_randval); + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return FmhaDropout{i_batch_, + i_nhead_, + kargs.num_head_q, + kargs.drop_seed, + kargs.drop_offset, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t}; + } + else + { + return FmhaDropout{}; + }; + }(); auto randval_dram_window = [&, i_nhead_ = i_nhead]() { constexpr auto randval_dram_window_lengths = make_tuple(number{}, number{}); - if constexpr(kHasDropout) + if constexpr(kIsStoreRandval) { RandValOutputDataType* rand_val_ptr = reinterpret_cast(kargs.rand_val_ptr) + @@ -1103,14 +1014,11 @@ struct FmhaBwdDQDKDVKernel }(); auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window, - qt_dram_window, k_dram_window, - kt_dram_window, v_dram_window, bias_dram_window, randval_dram_window, do_dram_window, - dot_dram_window, lse_dram_window, d_dram_window, dq_dram_window, @@ 
-1118,9 +1026,7 @@ struct FmhaBwdDQDKDVKernel mask, position_encoding, kargs.raw_scale, -#if CK_TILE_FMHA_FWD_FAST_EXP2 kargs.scale, -#endif rp_undrop, scale_rp_undrop, smem_ptr, @@ -1169,10 +1075,9 @@ struct FmhaBwdDQDKDVKernel } }; -template +template struct FmhaBwdOGradDotOKernel { - using TilePartitioner = ck_tile::remove_cvref_t; using FmhaBwdOGradDotO = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu; @@ -1234,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_d; - ck_tile::index_t batch_stride_d; }; struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs { ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_d; }; struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs @@ -1278,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel stride_o, nhead_stride_do, nhead_stride_o, - nhead_stride_d, - batch_stride_d}, + nhead_stride_d}, batch_stride_do, - batch_stride_o}; + batch_stride_o, + batch_stride_d}; return kargs; } @@ -1298,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_d, - ck_tile::index_t batch_stride_d) + ck_tile::index_t nhead_stride_d) { Kargs kargs{{o_ptr, do_ptr, @@ -1311,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel stride_o, nhead_stride_do, nhead_stride_o, - nhead_stride_d, - batch_stride_d}, + nhead_stride_d}, reinterpret_cast(seqstart_q_ptr)}; return kargs; @@ -1321,7 +1224,16 @@ struct FmhaBwdOGradDotOKernel CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) { - return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_); + return 
dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex() + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -1331,7 +1243,7 @@ struct FmhaBwdOGradDotOKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { // divide problem - const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q); + const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex(); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); @@ -1346,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel batch_offset_o = query_start * kargs.stride_o; batch_offset_do = query_start * kargs.stride_do; - batch_offset_d = static_cast(i_batch) * kargs.batch_stride_d; + batch_offset_d = query_start; // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; @@ -1418,4 +1330,315 @@ struct FmhaBwdOGradDotOKernel } }; +template +struct FmhaBwdConvertQGradKernel +{ + using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu; + static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0; + static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0; + static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim; + + using AccDataType = ck_tile::remove_cvref_t; + using QGradDataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ; + static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ; + static constexpr bool kIsDeterministic = 
FmhaBwdConvertQGrad::kIsDeterministic; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadHeadDimQ) n += "d"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + (kIsDeterministic ? "_deterministic" : "") + "_" + + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + // to avoid duplicated base class prblem, introduce an template arg + template + struct FmhaBwdConvertQGradEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. 
+ struct FmhaBwdConvertQGradCommonKargs + { + const void* dq_acc_ptr; + void* dq_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + + ck_tile::index_t stride_dq; + ck_tile::index_t stride_dq_acc; + ck_tile::index_t nhead_stride_dq; + ck_tile::index_t nhead_stride_dq_acc; + }; + + struct FmhaBwdConvertQGradDeterministicKargs + { + ck_tile::index_t split_stride_dq_acc = 0; + }; + + struct FmhaBwdConvertQGradBatchModeKargs + : FmhaBwdConvertQGradCommonKargs, + std::conditional_t> + { + ck_tile::index_t batch_stride_dq; + ck_tile::index_t batch_stride_dq_acc; + }; + + struct FmhaBwdConvertQGradGroupModeKargs + : FmhaBwdConvertQGradCommonKargs, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + }; + + using Kargs = std::conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* dq_acc_ptr, + void* dq_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t stride_dq, + ck_tile::index_t stride_dq_acc, + ck_tile::index_t nhead_stride_dq, + ck_tile::index_t nhead_stride_dq_acc, + ck_tile::index_t batch_stride_dq, + ck_tile::index_t batch_stride_dq_acc, + ck_tile::index_t split_stride_dq_acc) + { + Kargs kargs{{dq_acc_ptr, + dq_ptr, + seqlen_q, + seqlen_k, + hdim_q, + stride_dq, + stride_dq_acc, + nhead_stride_dq, + nhead_stride_dq_acc}, + {}, + batch_stride_dq, + batch_stride_dq_acc}; + + if constexpr(kIsDeterministic) + { + kargs.split_stride_dq_acc = split_stride_dq_acc; + } + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* dq_acc_ptr, + void* dq_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t stride_dq, + ck_tile::index_t stride_dq_acc, + ck_tile::index_t nhead_stride_dq, + ck_tile::index_t nhead_stride_dq_acc, + ck_tile::index_t split_stride_dq_acc) + { + Kargs 
kargs{{dq_acc_ptr, + dq_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + stride_dq, + stride_dq_acc, + nhead_stride_dq, + nhead_stride_dq_acc}, + {}, + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr)}; + + if constexpr(kIsDeterministic) + { + kargs.split_stride_dq_acc = split_stride_dq_acc; + } + + return kargs; + } + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) + { + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex() + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // divide problem + const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex(); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); + + long_index_t batch_offset_dq = 0; + long_index_t batch_offset_dq_acc = 0; + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + batch_offset_dq = query_start * kargs.stride_dq; + batch_offset_dq_acc = query_start * kargs.stride_dq_acc; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + if constexpr(kIsDeterministic) + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + // # of required blocks is different in each groups, terminate unnecessary 
blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + } + else + { + batch_offset_dq = static_cast(i_batch) * kargs.batch_stride_dq; + batch_offset_dq_acc = static_cast(i_batch) * kargs.batch_stride_dq_acc; + } + + // for simplicity, batch stride we just modify the pointer + QGradDataType* dq_ptr = reinterpret_cast(kargs.dq_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_dq + + batch_offset_dq; + + // dQAcc/dQ DRAM and DRAM window + const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() { + if constexpr(kIsDeterministic) + { + const AccDataType* dq_acc_ptr = + reinterpret_cast(kargs.dq_acc_ptr) + + static_cast(i_nhead_) * (kargs.nhead_stride_dq_acc) + + batch_offset_dq_acc; + + const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0); + + auto dq_acc_dram_naive = make_naive_tensor_view( + dq_acc_ptr, + make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1), + number{}, + number<1>{}); + return pad_tensor_view(dq_acc_dram_naive, + make_tuple(number<1>{}, number{}, number{}), + sequence{}); + } + else + { + const AccDataType* dq_acc_ptr = + reinterpret_cast(kargs.dq_acc_ptr) + + static_cast(i_nhead_) * (kargs.nhead_stride_dq_acc) + + batch_offset_dq_acc; + + auto dq_acc_dram_naive = make_naive_tensor_view( + dq_acc_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_dq_acc, 1), + number{}, + number<1>{}); + return pad_tensor_view(dq_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto dq_dram = [&]() { + auto dq_dram_naive = make_naive_tensor_view( + dq_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_dq, 1), + number{}, + number<1>{}); + return pad_tensor_view(dq_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto dq_acc_dram_window = [&]() { + if constexpr(kIsDeterministic) + { + return make_tile_window( + dq_acc_dram, + make_tuple(number<1>{}, number{}, number{}), + {0, 
i_m0, 0}); + } + else + { + return make_tile_window( + dq_acc_dram, make_tuple(number{}, number{}), {i_m0, 0}); + } + }(); + + auto dq_dram_window = + make_tile_window(dq_dram, make_tuple(number{}, number{}), {i_m0, 0}); + + if constexpr(kIsDeterministic) + { + const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0); + FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits); + } + else + { + FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window); + } + } +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp deleted file mode 100644 index bc875b8e5a..0000000000 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { - -template -struct FmhaBwdTilePartitioner -{ - using BlockFmhaShape = ck_tile::remove_cvref_t; - - static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; - - CK_TILE_HOST static constexpr auto - GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) - { - // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_); - } - - CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/) - { - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - return ck_tile::make_tuple(i_block, i_nhead, i_batch); - } -}; - -template -struct FmhaBwdOGradDotOTilePartitioner -{ - CK_TILE_HOST static constexpr auto - GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) - { - // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_); - } - - CK_TILE_DEVICE auto 
operator()(ck_tile::index_t /*seqlen_q*/) - { - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - return ck_tile::make_tuple(i_block, i_nhead, i_batch); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 5ecc3a4d80..49ef7bf6d9 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -86,7 +86,7 @@ struct FmhaFwdKernel "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + - (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? 
"_squant" : "" ); #undef _SS_ #undef _TS_ @@ -387,7 +387,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_lse, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, @@ -448,7 +447,6 @@ struct FmhaFwdKernel { kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; - kargs.batch_stride_lse = batch_stride_lse; } if constexpr(kDoFp8StaticQuant) { @@ -524,7 +522,7 @@ struct FmhaFwdKernel } if constexpr(kStoreLSE) { - batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + batch_offset_lse = query_start; } if constexpr(kHasDropout) { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index 6f4313d5b6..e2c7db3e1b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + - (kStoreLSE ? "_lse" : "" ) + + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? 
"_squant" : "" ); #undef _SS_ #undef _TS_ @@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o; - ck_tile::index_t batch_stride_lse_acc; ck_tile::index_t batch_stride_o_acc; ck_tile::index_t split_stride_lse_acc; @@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel std::conditional_t> { ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_lse_acc; }; struct GroupModeKargs @@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel nhead_stride_lse_acc, nhead_stride_o_acc, nhead_stride_o, - batch_stride_lse_acc, batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // placeholder for lse {}, // placeholder for fp8_static_quant args - batch_stride_o}; + batch_stride_o, + batch_stride_lse_acc}; if constexpr(kStoreLSE) { @@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, - ck_tile::index_t batch_stride_lse, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc) { @@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel nhead_stride_lse_acc, nhead_stride_o_acc, nhead_stride_o, - batch_stride_lse_acc, batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg @@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel { kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; - kargs.batch_stride_lse = batch_stride_lse; } if constexpr(kDoFp8StaticQuant) { @@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - const long_index_t batch_offset_lse_acc = - static_cast(i_batch) * kargs.batch_stride_lse_acc; const long_index_t batch_offset_o_acc = 
static_cast(i_batch) * kargs.batch_stride_o_acc; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; - if constexpr(kStoreLSE) - { - batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; - } + long_index_t batch_offset_lse_acc = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; if constexpr(kIsGroupMode) { // get starting offset for each batch const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - batch_offset_o = query_start * kargs.row_stride_o; + batch_offset_o = query_start * kargs.row_stride_o; + batch_offset_lse_acc = query_start; + + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; @@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel } else { - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + batch_offset_lse_acc = static_cast(i_batch) * kargs.batch_stride_lse_acc; + + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } } // for simplicity, batch stride we just modify the pointer diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 71ddf7b1f5..04c85892ac 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -219,6 +219,7 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_lse_acc; }; using Kargs = std::conditional_t; @@ -296,8 +297,7 @@ struct FmhaFwdSplitKVKernel nhead_stride_v, nhead_stride_lse_acc, nhead_stride_o_acc, - batch_stride_lse_acc, - batch_stride_o_acc, + batch_stride_lse_acc batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // 
placeholder for bias @@ -377,8 +377,7 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_lse_acc, - ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t batch_stride_lse_acc ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, @@ -478,12 +477,11 @@ struct FmhaFwdSplitKVKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - const long_index_t batch_offset_lse_acc = - static_cast(i_batch) * kargs.batch_stride_lse_acc; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_lse_acc = 0; const long_index_t batch_offset_o_acc = static_cast(i_batch) * kargs.batch_stride_o_acc; @@ -493,8 +491,9 @@ struct FmhaFwdSplitKVKernel const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_lse_acc = query_start; if constexpr(std::is_same_v) { batch_offset_v = key_start * kargs.stride_v; @@ -541,9 +540,10 @@ struct FmhaFwdSplitKVKernel } }(); - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_cache_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_cache_batch) * kargs.batch_stride_v; + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = 
static_cast(i_cache_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_cache_batch) * kargs.batch_stride_v; + batch_offset_lse_acc = static_cast(i_batch) * kargs.batch_stride_lse_acc; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp new file mode 100644 index 0000000000..3da1104169 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdConvertQGrad +{ + using AccDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + + static constexpr index_t kM0 = Problem::kM0; + static constexpr index_t kN0 = Problem::kN0; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kQKHeaddim = Problem::kQKHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + + static constexpr index_t kAlignmentQGradAcc = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 
1 : Policy::template GetAlignmentPostQGrad(); + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } + + // Convert only + template + CK_TILE_HOST_DEVICE void + operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp, + QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!"); + + auto dq_acc_dram_window = + make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(), + dq_acc_dram_block_window_tmp.get_window_lengths(), + dq_acc_dram_block_window_tmp.get_window_origin(), + Policy::template MakePostQGradDramTileDistribution()); + + auto dq_acc = load_tile(dq_acc_dram_window); + const auto dq = cast_tile(dq_acc); + + store_tile(dq_dram_block_window_tmp, dq); + } + + // Reduce + Convert + template + CK_TILE_HOST_DEVICE void + operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp, + QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + index_t nsplits) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!"); + + auto dq_acc_dram_window = + make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(), + dq_acc_dram_block_window_tmp.get_window_lengths(), + dq_acc_dram_block_window_tmp.get_window_origin(), + Policy::template MakePostQGradAccDramTileDistribution()); + + auto dq_acc = decltype(load_tile(dq_acc_dram_window)){}; + clear_tile(dq_acc); + + constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans(); + index_t i_total_loops = 0; + auto dq_acc_buf = load_tile(dq_acc_dram_window); + move_tile_window(dq_acc_dram_window, {1, 0, 0}); + + do + { + sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) { + 
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) { + constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2); + dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx); + }); + }); + }); + + dq_acc_buf = load_tile(dq_acc_dram_window); + move_tile_window(dq_acc_dram_window, {1, 0, 0}); + + i_total_loops += 1; + } while(i_total_loops < (nsplits - 1)); + + sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) { + sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) { + constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2); + dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx); + }); + }); + }); + + // declare dq + constexpr auto dq_converted_dstr = + Policy::template MakePostQGradAccDramTileDistribution(); + auto dq_converted = make_static_distributed_tensor(dq_converted_dstr); + + sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) { + sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) { + constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2); + dq_converted(n_i_j_idx) = type_convert(dq_acc[n_i_j_idx]); + }); + }); + }); + + constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution(); + auto dq = make_static_distributed_tensor(dq_dstr); + dq.get_thread_buffer() = dq_converted.get_thread_buffer(); + + store_tile(dq_dram_block_window_tmp, dq); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp index f189937038..c38779d1d2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp @@ -4,11 +4,11 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" namespace ck_tile 
{ -template +template struct BlockFmhaBwdOGradDotO { using ODataType = remove_cvref_t; @@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO static constexpr index_t kAlignmentO = kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentOGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp deleted file mode 100644 index 7843ab33a1..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" - -namespace ck_tile { - -// These templates are not used here. -using BlockFmhaBwdOGradDotODefaultPolicy = - BlockFmhaBwdPipelineDefaultPolicy; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp new file mode 100644 index 0000000000..131729992b --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -0,0 +1,782 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineKRKTRVR +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using FmhaDropout = remove_cvref_t; + using HotLoopScheduler = typename Policy::template HotLoopScheduler; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool 
kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = 1; + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 
1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "kr_ktr_vr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, + float scale, + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + FmhaDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + 
"wrong!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + // K, HBM ->LDS ->Reg + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + const auto k_origin = k_dram_window.get_window_origin(); + // Early termination + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. 
+ return make_tuple(dk_acc, dv_acc); + } + } + KDataType* k_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); + + auto k_lds_write_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + auto k_lds_read_window = + make_tile_window(k_lds_write_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + k_lds_write_window.get_window_origin(), + Policy::template MakeKRegSliceBlockDescriptor()); + + auto k_reg_tensor = make_static_distributed_tensor( + Policy::template MakeKRegBlockDescriptor()); + + //------------------------------------------------------------------ + // V, HBM ->LDS ->Reg + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVDramTileDistribution()); + + VDataType* v_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + + auto v_lds = make_tensor_view( + v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); + + auto v_lds_write_window = + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + + auto v_lds_read_window = + make_tile_window(v_lds_write_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + v_lds_write_window.get_window_origin(), + Policy::template MakeVRegSliceBlockDescriptor()); + + auto v_reg_tensor = make_static_distributed_tensor( + Policy::template MakeVRegBlockDescriptor()); + + //------------------------------------------------------------------ + // KT, Reg ->LDS ->Reg + auto shuffled_k_block_tile = make_static_distributed_tensor( + Policy::template MakeShuffledKRegWriteBlockDescriptor()); + + KDataType* kt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + + auto shuffled_k_lds_write = make_tensor_view( + kt_lds_ptr, Policy::template 
MakeShuffledKLdsWriteBlockDescriptor()); + + auto shuffled_k_lds_write_window = make_tile_window( + shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); + + auto kt_lds_read = make_tensor_view( + kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor()); + + auto kt_lds_read_window = + make_tile_window(kt_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKTRegBlockDescriptor()); + + //------------------------------------------------------------------ + // Pre-Load KV into Registers + auto k_block_tile = load_tile(k_dram_window); + auto v_block_tile = load_tile(v_dram_window); + + store_tile(k_lds_write_window, k_block_tile); + shuffle_tile(shuffled_k_block_tile, k_block_tile); + store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile); + + block_sync_lds(); + k_reg_tensor = load_tile(k_lds_read_window); + block_sync_lds(); + + auto kt_reg_tensor = load_tile(kt_lds_read_window); + + store_tile(v_lds_write_window, v_block_tile); + + block_sync_lds(); + + v_reg_tensor = load_tile(v_lds_read_window); + block_sync_lds(); + //---------------------------- Loop Load in ----------------------------// + // Q: HBM ->Reg ->LDS + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeQDramTileDistribution()); + + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGradT())); + + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + auto q_lds_read_window = + make_tile_window(q_lds_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_lds_window.get_window_origin(), + Policy::template MakeQRegSliceBlockDescriptor()); + + auto 
pt_reg_tensor = make_static_distributed_tensor( + Policy::template MakePTRegSliceBlockDescriptor()); + // QT: Reg -> Reg-> LDS + auto shuffled_q_block_tile = make_static_distributed_tensor( + Policy::template MakeShuffledQRegWriteBlockDescriptor()); + + QDataType* qt_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + + auto shuffled_q_lds_write = make_tensor_view( + qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor()); + + auto shuffled_q_lds_write_window = make_tile_window( + shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); + + auto qt_lds_read = make_tensor_view( + qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor()); + + auto qt_lds_read_window = + make_tile_window(qt_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeQTRegSliceBlockDescriptor()); + + // dO: HBM ->Reg ->LDS + auto do_dram_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeOGradDramTileDistribution()); + + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT())); + + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + auto do_lds_read_window = + make_tile_window(do_lds_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + do_lds_window.get_window_origin(), + Policy::template MakeOGradRegSliceBlockDescriptor()); + // dOT: Reg ->Reg ->LDS + auto shuffled_do_block_tile = make_static_distributed_tensor( + Policy::template MakeShuffledOGradRegWriteBlockDescriptor()); + + OGradDataType* dot_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad())); + + auto shuffled_do_lds_write = make_tensor_view( + dot_lds_ptr, 
Policy::template MakeShuffledOGradLdsWriteBlockDescriptor()); + + auto shuffled_do_lds_write_window = make_tile_window( + shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); + + auto dot_read_lds = make_tensor_view( + dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor()); + + auto dot_lds_read_window = + make_tile_window(dot_read_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeOGradTRegSliceBlockDescriptor()); + + // dS: Reg -> Reg -> LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGradT() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE() + + Policy::template GetSmemSizeD())); + + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + auto ds_lds_read_window = + make_tile_window(ds_lds_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + ds_lds_window.get_window_origin(), + Policy::template MakeSGradRegSliceBlockDescriptor()); + + auto dst_reg_tensor = make_static_distributed_tensor( + Policy::template MakeSGradTRegSliceBlockDescriptor()); + // Bias: HBM ->Reg ->Reg ->LDS + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}, + Policy::template MakeBiasTileDistribution()); + + BiasDataType* bias_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGradT() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE() + + Policy::template GetSmemSizeD())); + + auto 
bias_lds = make_tensor_view( + bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor()); + + auto bias_lds_write_window = + make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); + + auto bias_s_lds_read_window = + make_tile_window(bias_lds_write_window.get_bottom_tensor_view(), + bias_lds_write_window.get_window_lengths(), + bias_lds_write_window.get_window_origin(), + Policy::template MakeBiasSTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // LSE: HBM -> LDS ->Reg + auto lse_dram_window = make_tile_window( + lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); + + LSEDataType* lse_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGradT() + + Policy::template GetSmemSizeQ())); + + auto lse_lds = make_tensor_view( + lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + + auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number{}), {0}); + + auto lse_lds_read_window = make_tile_window( + lse_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // D: HBM ->Reg + auto d_dram_window = make_tile_window( + d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); + + DDataType* d_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGradT() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE())); + + auto d_lds = make_tensor_view( + d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + + auto d_lds_write_window = 
make_tile_window(d_lds, make_tuple(number{}), {0}); + + auto d_lds_read_window = make_tile_window( + d_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // RandVal: HBM ->Reg + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + // BiasGrad + // Reg ->LDS ->Reg ->HBM + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + + auto dbias_dram_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto dbias_lds_read_window = + make_tile_window(bias_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + // ----------------------------Loop write out------------------------------// + auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + using SPBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + index_t i_total_loops = 0; + index_t seqlen_q_step = seqlen_q_start; + static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0"); + static_assert(kM0 == kK1, "kM0 should equal to kK1"); + static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); + static_assert(kM0 == kK3, "kM0 should equal to kK3"); + constexpr index_t k4_loops = kN0 / kK4; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + __builtin_amdgcn_sched_barrier(0); + // Hot loop + while(i_total_loops < num_total_loop) + { + auto q_block_tile = load_tile(q_dram_window); + move_tile_window(q_dram_window, {kM0, 0}); + + auto lse_block_tile = load_tile(lse_dram_window); + move_tile_window(lse_dram_window, {kM0}); + + store_tile(q_lds_window, 
q_block_tile); + shuffle_tile(shuffled_q_block_tile, q_block_tile); + store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile); + + store_tile(lse_lds_write_window, lse_block_tile); + + block_sync_lds(); + + auto q_reg_tensor = load_tile(q_lds_read_window); + auto lse = load_tile(lse_lds_read_window); + + block_sync_lds(); + + // STAGE 1, Q@K Gemm0 + auto s_acc = SPBlockTileType{}; + + s_acc = gemm_0(q_reg_tensor, k_reg_tensor); + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const auto bias_tile = load_tile(bias_dram_window); + auto shuffled_bias_tile = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(shuffled_bias_tile, bias_tile); + store_tile(bias_lds_write_window, shuffled_bias_tile); + block_sync_lds(); + auto bias_s_tile = load_tile(bias_s_lds_read_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { + x = scale * x + log2e_v * type_convert(y); + }, + s_acc, + bias_s_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + bool need_perpixel_check = mask.IsEdgeTile( + seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); + if(need_perpixel_check) + { + 
set_tile_if(s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto p = SPBlockTileType{}; + constexpr auto p_spans = decltype(p)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); + + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse); + } + else + { + p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse); + } + }); + }); + + if constexpr(FmhaDropout::IsDropout) + { + dropout.template Run( + seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window); + } + const auto p_gemm = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? 
x : 0.f); }, + p); + } + else + { + return cast_tile(p); + } + }(); + + // STAGE 3, P^T@OGrad^T Gemm1 + auto do_block_tile = load_tile(do_dram_window); + move_tile_window(do_dram_window, {kM0, 0}); + + auto d_block_tile = load_tile(d_dram_window); + move_tile_window(d_dram_window, {kM0}); + + store_tile(do_lds_window, do_block_tile); + shuffle_tile(shuffled_do_block_tile, do_block_tile); + store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile); + + store_tile(d_lds_write_window, d_block_tile); + + block_sync_lds(); + + auto dot_reg_tensor = load_tile(dot_lds_read_window); + + block_sync_lds(); + + Policy::template PTFromGemm0CToGemm1A(pt_reg_tensor, p_gemm); + gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); + + // STAGE 4, OGrad@V Gemm2 + auto do_reg_tensor = load_tile(do_lds_read_window); + auto d = load_tile(d_lds_read_window); + block_sync_lds(); + + auto dp_acc = SPGradBlockTileType{}; + + dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); + + // STAGE 5, P^T(PGrad^T - D) + auto ds = SPGradBlockTileType{}; + constexpr auto ds_spans = decltype(ds)::get_distributed_spans(); + sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = p[i_j_idx] >= 0; + ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag + ? 
(dp_acc[i_j_idx] - d[i_idx]) + : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbias = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + ds); + } + else + { + return cast_tile(ds); + } + }(); + store_tile(bias_lds_write_window, dbias); + block_sync_lds(); + auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); + auto dbias_tile = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbias_tile, shuffled_dbias_tile); + store_tile(dbias_dram_window, dbias_tile); + move_tile_window(dbias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + auto qt_reg_tensor = load_tile(qt_lds_read_window); + block_sync_lds(); + + const auto ds_gemm = cast_tile(ds); + + Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, ds_gemm); + + gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + + store_tile(ds_lds_window, ds_gemm); + + block_sync_lds(); + + auto ds_reg_tensor = load_tile(ds_lds_read_window); + auto ds_reg_tensor_next = decltype(ds_reg_tensor){}; + move_tile_window(ds_lds_read_window, {0, kK4}); + + // STAGE7 SGrad@K^T Gemm4 + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor_next = load_tile(ds_lds_read_window); + move_tile_window(ds_lds_read_window, {0, kK4}); + } + auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor, + sequence<0, i_k4 * kK4>{}, + sequence{}); + gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice); + + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer(); + } + }); + move_tile_window(ds_lds_read_window, {0, -kN0}); + // QGrad Scale + if constexpr(FmhaDropout::IsDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, 
+ dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + if constexpr(kIsDeterministic) + { + store_tile(dq_dram_window, dq_acc); + } + else + { + update_tile(dq_dram_window, dq_acc); + } + move_tile_window(dq_dram_window, {kM0, 0}); + + i_total_loops += 1; + seqlen_q_step += kM0; + } + + // Results Scale + if constexpr(FmhaDropout::IsDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + + return make_tuple(dk_acc, dv_acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp new file mode 100644 index 0000000000..9e6a2725c9 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -0,0 +1,1037 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using FmhaDropout = remove_cvref_t; + using HotLoopScheduler = typename Policy::template HotLoopScheduler; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr 
bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = 1; + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 
1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "kr_ktr_vr_iglp"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, + float scale, + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + FmhaDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + 
"wrong!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + // K, HBM ->LDS ->Reg + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + const auto k_origin = k_dram_window.get_window_origin(); + // Early termination + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. 
+ return make_tuple(dk_acc, dv_acc); + } + } + KDataType* k_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); + + auto k_lds_write_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + auto k_lds_read_window = + make_tile_window(k_lds_write_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + k_lds_write_window.get_window_origin(), + Policy::template MakeKRegSliceBlockDescriptor()); + + auto k_reg_tensor = make_static_distributed_tensor( + Policy::template MakeKRegBlockDescriptor()); + + //------------------------------------------------------------------ + // V, HBM ->LDS ->Reg + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVDramTileDistribution()); + + VDataType* v_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + + auto v_lds = make_tensor_view( + v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); + + auto v_lds_write_window = + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + + auto v_lds_read_window = + make_tile_window(v_lds_write_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + v_lds_write_window.get_window_origin(), + Policy::template MakeVRegSliceBlockDescriptor()); + + auto v_reg_tensor = make_static_distributed_tensor( + Policy::template MakeVRegBlockDescriptor()); + + //------------------------------------------------------------------ + // KT, Reg ->LDS ->Reg + auto shuffled_k_block_tile = make_static_distributed_tensor( + Policy::template MakeShuffledKRegWriteBlockDescriptor()); + + KDataType* kt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + + auto shuffled_k_lds_write = make_tensor_view( + kt_lds_ptr, Policy::template 
MakeShuffledKLdsWriteBlockDescriptor()); + + auto shuffled_k_lds_write_window = make_tile_window( + shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); + + auto kt_lds_read = make_tensor_view( + kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor()); + + auto kt_lds_read_window = + make_tile_window(kt_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKTRegBlockDescriptor()); + + //------------------------------------------------------------------ + // Pre-Load KV into Registers + auto k_block_tile = load_tile(k_dram_window); + auto v_block_tile = load_tile(v_dram_window); + + store_tile(k_lds_write_window, k_block_tile); + shuffle_tile(shuffled_k_block_tile, k_block_tile); + store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile); + + block_sync_lds(); + k_reg_tensor = load_tile(k_lds_read_window); + block_sync_lds(); + + auto kt_reg_tensor = load_tile(kt_lds_read_window); + + store_tile(v_lds_write_window, v_block_tile); + + block_sync_lds(); + + v_reg_tensor = load_tile(v_lds_read_window); + //---------------------------- Loop Load in ----------------------------// + // Q: HBM ->Reg ->LDS + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeQDramTileDistribution()); + + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGradT())); + + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + auto q_lds_read_window = + make_tile_window(q_lds_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_lds_window.get_window_origin(), + Policy::template MakeQRegSliceBlockDescriptor()); + + auto pt_reg_tensor = 
make_static_distributed_tensor( + Policy::template MakePTRegSliceBlockDescriptor()); + // QT: Reg -> Reg-> LDS + auto shuffled_q_block_tile = make_static_distributed_tensor( + Policy::template MakeShuffledQRegWriteBlockDescriptor()); + + QDataType* qt_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + + auto shuffled_q_lds_write = make_tensor_view( + qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor()); + + auto shuffled_q_lds_write_window = make_tile_window( + shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); + + auto qt_lds_read = make_tensor_view( + qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor()); + + auto qt_lds_read_window = + make_tile_window(qt_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeQTRegSliceBlockDescriptor()); + + // dO: HBM ->Reg ->LDS + auto do_dram_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeOGradDramTileDistribution()); + + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT())); + + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + auto do_lds_read_window = + make_tile_window(do_lds_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + do_lds_window.get_window_origin(), + Policy::template MakeOGradRegSliceBlockDescriptor()); + // dOT: Reg ->Reg ->LDS + auto shuffled_do_block_tile = make_static_distributed_tensor( + Policy::template MakeShuffledOGradRegWriteBlockDescriptor()); + + OGradDataType* dot_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad())); + + auto shuffled_do_lds_write = make_tensor_view( + dot_lds_ptr, Policy::template 
MakeShuffledOGradLdsWriteBlockDescriptor()); + + auto shuffled_do_lds_write_window = make_tile_window( + shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); + + auto dot_read_lds = make_tensor_view( + dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor()); + + auto dot_lds_read_window = + make_tile_window(dot_read_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeOGradTRegSliceBlockDescriptor()); + + // dS: Reg -> Reg -> LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGradT() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE() + + Policy::template GetSmemSizeD())); + + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + auto ds_lds_read_window = + make_tile_window(ds_lds_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + ds_lds_window.get_window_origin(), + Policy::template MakeSGradRegSliceBlockDescriptor()); + + auto dst_reg_tensor = make_static_distributed_tensor( + Policy::template MakeSGradTRegSliceBlockDescriptor()); + // Bias: HBM ->Reg ->Reg ->LDS + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}, + Policy::template MakeBiasTileDistribution()); + + BiasDataType* bias_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGradT() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE() + + Policy::template GetSmemSizeD())); + + auto bias_lds = 
make_tensor_view( + bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor()); + + auto bias_lds_write_window = + make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); + + auto bias_s_lds_read_window = + make_tile_window(bias_lds_write_window.get_bottom_tensor_view(), + bias_lds_write_window.get_window_lengths(), + bias_lds_write_window.get_window_origin(), + Policy::template MakeBiasSTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // LSE: HBM -> LDS ->Reg + auto lse_dram_window = make_tile_window( + lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); + + LSEDataType* lse_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGradT() + + Policy::template GetSmemSizeQ())); + + auto lse_lds = make_tensor_view( + lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + + auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number{}), {0}); + + auto lse_lds_read_window = make_tile_window( + lse_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // D: HBM ->Reg + auto d_dram_window = make_tile_window( + d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); + + DDataType* d_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGradT() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE())); + + auto d_lds = make_tensor_view( + d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + + auto d_lds_write_window = make_tile_window(d_lds, 
make_tuple(number{}), {0}); + + auto d_lds_read_window = make_tile_window( + d_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // RandVal: HBM ->Reg + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + // BiasGrad + // Reg ->LDS ->Reg ->HBM + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + + auto dbias_dram_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto dbias_lds_read_window = + make_tile_window(bias_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + // ----------------------------Loop write out------------------------------// + auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + using SPBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + index_t i_total_loops = 0; + index_t seqlen_q_step = seqlen_q_start; + static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0"); + static_assert(kM0 == kK1, "kM0 should equal to kK1"); + static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); + static_assert(kM0 == kK3, "kM0 should equal to kK3"); + constexpr index_t k4_loops = kN0 / kK4; + + /* + * Prefetch Q, LSE, dO, D + */ + auto q_block_tile = load_tile(q_dram_window); + move_tile_window(q_dram_window, {kM0, 0}); + auto lse_block_tile = load_tile(lse_dram_window); + move_tile_window(lse_dram_window, {kM0}); + + auto do_block_tile = load_tile(do_dram_window); + move_tile_window(do_dram_window, {kM0, 0}); + + auto d_block_tile = load_tile(d_dram_window); + 
move_tile_window(d_dram_window, {kM0}); + + /* + * Store prefetched data into LDS + */ + block_sync_lds(); + store_tile(q_lds_window, q_block_tile); + shuffle_tile(shuffled_q_block_tile, q_block_tile); + store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile); + + store_tile(lse_lds_write_window, lse_block_tile); + + store_tile(do_lds_window, do_block_tile); + shuffle_tile(shuffled_do_block_tile, do_block_tile); + store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile); + + store_tile(d_lds_write_window, d_block_tile); + block_sync_lds(); + + /* + * Prefetch LDS data into Reg to Asynchronous Data Movement and MFMA pipeline + */ + + auto q_reg_tensor = load_tile(q_lds_read_window); + auto lse = load_tile(lse_lds_read_window); + auto do_reg_tensor = load_tile(do_lds_read_window); + auto d = load_tile(d_lds_read_window); + + clear_tile(dv_acc); + clear_tile(dk_acc); + + __builtin_amdgcn_sched_barrier(0); + // Hot loop + while(i_total_loops < (num_total_loop - 1)) + { + // STAGE 1, Q@K Gemm0 + auto s_acc = SPBlockTileType{}; + + q_block_tile = load_tile(q_dram_window); + move_tile_window(q_dram_window, {kM0, 0}); + + lse_block_tile = load_tile(lse_dram_window); + move_tile_window(lse_dram_window, {kM0}); + + do_block_tile = load_tile(do_dram_window); + move_tile_window(do_dram_window, {kM0, 0}); + + d_block_tile = load_tile(d_dram_window); + move_tile_window(d_dram_window, {kM0}); + + s_acc = gemm_0(q_reg_tensor, k_reg_tensor); + + auto dot_reg_tensor = load_tile(dot_lds_read_window); + + HotLoopScheduler::template GemmStagedScheduler<0>(); + __builtin_amdgcn_sched_barrier(0); + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const auto bias_tile = load_tile(bias_dram_window); + auto shuffled_bias_tile = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(shuffled_bias_tile, bias_tile); + store_tile(bias_lds_write_window, 
shuffled_bias_tile); + block_sync_lds(); + auto bias_s_tile = load_tile(bias_s_lds_read_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { + x = scale * x + log2e_v * type_convert(y); + }, + s_acc, + bias_s_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + bool need_perpixel_check = mask.IsEdgeTile( + seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); + if(need_perpixel_check) + { + set_tile_if(s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? 
type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto p = SPBlockTileType{}; + constexpr auto p_spans = decltype(p)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); + + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse); + } + else + { + p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse); + } + }); + }); + + if constexpr(FmhaDropout::IsDropout) + { + dropout.template Run( + seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window); + } + const auto p_gemm = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? x : 0.f); }, + p); + } + else + { + return cast_tile(p); + } + }(); + + // STAGE 3, P^T@OGrad^T Gemm1 + Policy::template PTFromGemm0CToGemm1A(pt_reg_tensor, p_gemm); + gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); + + auto qt_reg_tensor = load_tile(qt_lds_read_window); + + HotLoopScheduler::template GemmStagedScheduler<1>(); + __builtin_amdgcn_sched_barrier(0); + // STAGE 4, OGrad@V Gemm2 + auto dp_acc = SPGradBlockTileType{}; + + dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); + + block_sync_lds(); + + store_tile(q_lds_window, q_block_tile); + shuffle_tile(shuffled_q_block_tile, q_block_tile); + store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile); + + store_tile(lse_lds_write_window, lse_block_tile); + + store_tile(do_lds_window, do_block_tile); + shuffle_tile(shuffled_do_block_tile, do_block_tile); + store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile); + + store_tile(d_lds_write_window, d_block_tile); + + HotLoopScheduler::template GemmStagedScheduler<2>(); + 
__builtin_amdgcn_sched_barrier(0); + // STAGE 5, P^T(PGrad^T - D) + auto ds = SPGradBlockTileType{}; + constexpr auto ds_spans = decltype(ds)::get_distributed_spans(); + sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = p[i_j_idx] >= 0; + ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag + ? (dp_acc[i_j_idx] - d[i_idx]) + : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbias = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + ds); + } + else + { + return cast_tile(ds); + } + }(); + store_tile(bias_lds_write_window, dbias); + block_sync_lds(); + auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); + auto dbias_tile = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbias_tile, shuffled_dbias_tile); + store_tile(dbias_dram_window, dbias_tile); + move_tile_window(dbias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + const auto ds_gemm = cast_tile(ds); + + Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, ds_gemm); + + gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + + store_tile(ds_lds_window, ds_gemm); + + block_sync_lds(); + + auto ds_reg_tensor = load_tile(ds_lds_read_window); + auto ds_reg_tensor_next = decltype(ds_reg_tensor){}; + move_tile_window(ds_lds_read_window, {0, kK4}); + q_reg_tensor = load_tile(q_lds_read_window); + lse = load_tile(lse_lds_read_window); + + HotLoopScheduler::template GemmStagedScheduler<3>(); + __builtin_amdgcn_sched_barrier(0); + // STAGE7 SGrad@K^T Gemm4 + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + if constexpr(i_k4 < k4_loops - 1) + 
{ + ds_reg_tensor_next = load_tile(ds_lds_read_window); + move_tile_window(ds_lds_read_window, {0, kK4}); + } + auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor, + sequence<0, i_k4 * kK4>{}, + sequence{}); + gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice); + + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer(); + } + }); + move_tile_window(ds_lds_read_window, {0, -kN0}); + + do_reg_tensor = load_tile(do_lds_read_window); + d = load_tile(d_lds_read_window); + + HotLoopScheduler::template GemmStagedScheduler<4>(); + + // QGrad Scale + if constexpr(FmhaDropout::IsDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + if constexpr(kIsDeterministic) + { + store_tile(dq_dram_window, dq_acc); + } + else + { + update_tile(dq_dram_window, dq_acc); + } + move_tile_window(dq_dram_window, {kM0, 0}); + + i_total_loops += 1; + seqlen_q_step += kM0; + } + __builtin_amdgcn_sched_barrier(0); + + // Tail + auto s_acc = SPBlockTileType{}; + + // STAGE 1, Q@K Gemm0 + s_acc = gemm_0(q_reg_tensor, k_reg_tensor); + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const auto bias_tile = load_tile(bias_dram_window); + auto shuffled_bias_tile = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(shuffled_bias_tile, bias_tile); + store_tile(bias_lds_write_window, shuffled_bias_tile); + block_sync_lds(); + auto bias_s_tile = load_tile(bias_s_lds_read_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { + x = scale * x + log2e_v * type_convert(y); + }, + s_acc, + bias_s_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + 
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + bool need_perpixel_check = mask.IsEdgeTile( + seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); + if(need_perpixel_check) + { + set_tile_if(s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() ? 
type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto p = SPBlockTileType{}; + constexpr auto p_spans = decltype(p)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); + + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse); + } + else + { + p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse); + } + }); + }); + + if constexpr(FmhaDropout::IsDropout) + { + dropout.template Run( + seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 + const auto p_gemm = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? 
x : 0.f); }, p); + } + else + { + return cast_tile(p); + } + }(); + + Policy::template PTFromGemm0CToGemm1A( + pt_reg_tensor, p_gemm); + auto dot_reg_tensor = load_tile(dot_lds_read_window); + gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); + + HotLoopScheduler::template GemmStagedScheduler<1>(); + + // STAGE 4, OGrad@V Gemm2 + auto dp_acc = SPGradBlockTileType{}; + + auto qt_reg_tensor = load_tile(qt_lds_read_window); + + dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); + + HotLoopScheduler::template GemmStagedScheduler<2>(); + + // STAGE 5, P^T(PGrad^T - D) + auto ds = SPGradBlockTileType{}; + constexpr auto ds_spans = decltype(ds)::get_distributed_spans(); + sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = p[i_j_idx] >= 0; + ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag + ? (dp_acc[i_j_idx] - d[i_idx]) + : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbias = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + ds); + } + else + { + return cast_tile(ds); + } + }(); + store_tile(bias_lds_write_window, dbias); + block_sync_lds(); + auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); + auto dbias_tile = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbias_tile, shuffled_dbias_tile); + store_tile(dbias_dram_window, dbias_tile); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + const auto ds_gemm = cast_tile(ds); + + Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, ds_gemm); + + gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + store_tile(ds_lds_window, ds_gemm); + + block_sync_lds(); + + auto ds_reg_tensor = load_tile(ds_lds_read_window); + auto ds_reg_tensor_next = 
decltype(ds_reg_tensor){}; + move_tile_window(ds_lds_read_window, {0, kK4}); + + HotLoopScheduler::template GemmStagedScheduler<3>(); + // STAGE 7, SGrad@K^T Gemm4 + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor_next = load_tile(ds_lds_read_window); + move_tile_window(ds_lds_read_window, {0, kK4}); + } + auto kt_reg_tensor_slice = get_slice_tile( + kt_reg_tensor, sequence<0, i_k4 * kK4>{}, sequence{}); + + gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice); + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer(); + } + }); + + HotLoopScheduler::template GemmStagedScheduler<4>(); + + // Results Scale + if constexpr(FmhaDropout::IsDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + + if constexpr(kIsDeterministic) + { + store_tile(dq_dram_window, dq_acc); + } + else + { + update_tile(dq_dram_window, dq_acc); + } + + return make_tuple(dk_acc, dv_acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp deleted file mode 100644 index 3444567508..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp +++ /dev/null @@ -1,848 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/block/block_dropout.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" -#include "ck_tile/ops/reduce/block/block_reduce.hpp" - -namespace ck_tile { - -template -struct BlockFmhaBwdDQDKDVPipelineKSKTSVR -{ - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using GemmDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using DDataType = remove_cvref_t; - using RandValOutputDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using OGradDataType = remove_cvref_t; - using QGradDataType = remove_cvref_t; - using KGradDataType = remove_cvref_t; - using VGradDataType = remove_cvref_t; - using BiasGradDataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; - - using BlockFmhaShape = remove_cvref_t; - - static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; - static constexpr index_t kBlockSize = Problem::kBlockSize; - - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK2 = BlockFmhaShape::kK2; - static constexpr index_t kK3 = BlockFmhaShape::kK3; - static constexpr index_t kK4 = BlockFmhaShape::kK4; - static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; - static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; - - static constexpr bool kQLoadOnce = false; - static constexpr bool kQTLoadOnce = false; - static constexpr bool kKLoadOnce = true; - static constexpr bool kKTLoadOnce = true; - static constexpr bool kVLoadOnce = true; - static constexpr bool kOGradLoadOnce = false; - static constexpr bool 
kOGradTLoadOnce = false; - - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; - static constexpr bool kHasDropout = Problem::kHasDropout; - - // last dimension vector length used to create tensor view(and decide buffer_load vector length) - // ... together with tensor distribution. tensor dist should able to overwrite this - static constexpr index_t kAlignmentQ = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); - static constexpr index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); - static constexpr index_t kAlignmentV = - kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); - static constexpr index_t kAlignmentO = - kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); - static constexpr index_t kAlignmentOGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); - static constexpr index_t kAlignmentQGrad = - kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); - static constexpr index_t kAlignmentKGrad = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); - static constexpr index_t kAlignmentVGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); - static constexpr index_t kAlignmentBias = - kPadSeqLenK ? 
1 : Policy::template GetTransposedAlignmentBias(); - - static constexpr const char* name = "ks_kts_vr"; - - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, - const QTDramBlockWindowTmp& qt_dram_block_window_tmp, - const KDramBlockWindowTmp& k_dram_block_window_tmp, - const KTDramBlockWindowTmp& kt_dram_block_window_tmp, - const VDramBlockWindowTmp& v_dram_block_window_tmp, - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, - const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, - const OGradDramBlockWindowTmp& do_dram_block_window_tmp, - const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp, - const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, - const DDramBlockWindowTmp& d_dram_block_window_tmp, - const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, - const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, - FmhaMask mask, - PositionEncoding position_encoding, - float raw_scale, -#if CK_TILE_FMHA_FWD_FAST_EXP2 - float scale, -#endif - float rp_undrop, - float scale_rp_undrop, - void* smem_ptr, - BlockDropout& dropout) const - { - static_assert( - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == 
BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kVHeaddim == - OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); - - // Q tile in LDS - QDataType* q_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeKT())); - auto q_lds = make_tensor_view( - q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_window = - make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); - - // QT tile in LDS - QDataType* qt_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeKT())); - auto qt_lds = make_tensor_view( - qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor()); - auto qt_lds_window = - make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); - - // K tile in LDS - auto k_lds = make_tensor_view( - reinterpret_cast(smem_ptr), - Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); - - // KT tile in LDS - KDataType* kt_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK())); - auto kt_lds = make_tensor_view( - kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor()); - auto kt_lds_window = - make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); - - // OGrad tile in LDS - OGradDataType* do_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK() + - Policy::template 
GetSmemSizeKT())); - auto do_lds = make_tensor_view( - do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); - auto do_lds_window = - make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); - - // OGradT tile in LDS - OGradDataType* dot_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeKT())); - auto dot_lds = make_tensor_view( - dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor()); - auto dot_lds_window = - make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); - - // SGrad tile in LDS - GemmDataType* ds_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeKT())); - auto ds_lds = make_tensor_view( - ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); - auto ds_lds_window = - make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); - - // BiasT/BiasGradT tile in LDS, use the same size and layout - BiasDataType* biast_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeKT())); - auto biast_lds = make_tensor_view( - biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); - auto biast_lds_shuffle_window = - make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); - auto dbiast_lds_shuffle_window = - make_tile_window(biast_lds, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeShuffledBiasTileDistribution()); - - static_assert(std::is_same_v, - "BiasDataType and BiasGradDataType should be the same!"); - - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); - constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); - constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); - constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); - - auto v_dram_window 
= make_tile_window( - v_dram_block_window_tmp.get_bottom_tensor_view(), - v_dram_block_window_tmp.get_window_lengths(), - v_dram_block_window_tmp.get_window_origin(), - Policy::template MakeVInRegDramTileDistribution()); - - auto v = load_tile(v_dram_window); // persistent V register tile - - using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); - using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); - using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); - - // init VGrad & KGrad - auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; - auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; - - clear_tile(dv_acc); - clear_tile(dk_acc); - - auto k_dram_window = make_tile_window( - k_dram_block_window_tmp.get_bottom_tensor_view(), - k_dram_block_window_tmp.get_window_lengths(), - k_dram_block_window_tmp.get_window_origin(), - Policy::template MakeKDramTileDistribution()); // K DRAM tile window for - // load - - __builtin_amdgcn_sched_barrier(0); - const auto k_origin = k_dram_window.get_window_origin(); - const auto [seqlen_q_start, seqlen_q_end] = - mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); - - const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); - - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking) - { - if(num_total_loop <= 0) - { - // Note: here dk_acc&dv_acc are all cleard, return it - // Note: v loaded but no fence, ignore it. 
- return ck_tile::make_tuple(dk_acc, dv_acc); - } - } - - auto k_block_tile = load_tile(k_dram_window); - - store_tile(k_lds_window, k_block_tile); // // persistent K in LDS - - auto kt_dram_block_window = kt_dram_block_window_tmp; - - auto kt_dram_window = make_tile_window( - kt_dram_block_window.get_bottom_tensor_view(), - kt_dram_block_window.get_window_lengths(), - kt_dram_block_window.get_window_origin(), - Policy::template MakeKTDramTileDistribution()); // K^T DRAM tile window for - // load - - auto kt_block_tile = load_tile(kt_dram_window); - - auto kt_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledKTRegBlockDescriptor()); - shuffle_tile(kt_shuffle_tmp, kt_block_tile); - - store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS - - auto q_dram_block_window = - make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - q_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}); - - auto qt_dram_block_window = - make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(), - qt_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_q_start}); - - auto do_dram_block_window = - make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), - do_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}); - - auto dot_dram_block_window = - make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(), - dot_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_q_start}); - - auto dq_dram_block_window = - make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), - dq_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}); - - auto lse_dram_block_window = - make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), - lse_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start}); - - auto d_dram_block_window = - make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), - d_dram_block_window_tmp.get_window_lengths(), - 
{seqlen_q_start}); - - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - auto bias_dram_block_window = - make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N - - const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); - auto dbias_dram_block_window = - make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), - dbias_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N - - auto qt_dram_window = - make_tile_window(qt_dram_block_window.get_bottom_tensor_view(), - qt_dram_block_window.get_window_lengths(), - qt_dram_block_window.get_window_origin(), - Policy::template MakeQTDramTileDistribution()); - - auto dot_dram_window = - make_tile_window(dot_dram_block_window.get_bottom_tensor_view(), - dot_dram_block_window.get_window_lengths(), - dot_dram_block_window.get_window_origin(), - Policy::template MakeOGradTDramTileDistribution()); - - auto lse_dram_window = make_tile_window( - lse_dram_block_window.get_bottom_tensor_view(), - lse_dram_block_window.get_window_lengths(), - lse_dram_block_window.get_window_origin(), - Policy::template MakeLSEDDramTileDistribution()); - - auto d_dram_window = make_tile_window( - d_dram_block_window.get_bottom_tensor_view(), - d_dram_block_window.get_window_lengths(), - d_dram_block_window.get_window_origin(), - Policy::template MakeLSEDDramTileDistribution()); - - auto bias_dram_window = - make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), - bias_dram_block_window.get_window_lengths(), - bias_dram_block_window.get_window_origin(), - Policy::template MakeBiasTileDistribution()); - - auto biast_lds_window = - make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), - biast_lds_shuffle_window.get_window_lengths(), - biast_lds_shuffle_window.get_window_origin(), - Policy::template 
MakeBiasTTileDistribution()); - - auto randval_dram_window = dropout.MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_q_start); - - index_t i_total_loops = 0; - constexpr index_t k0_loops = kQKHeaddim / kK0; - constexpr index_t k1_loops = kM0 / kK1; - constexpr index_t k2_loops = kVHeaddim / kK2; - constexpr index_t k3_loops = kM0 / kK3; - constexpr index_t k4_loops = kN0 / kK4; - do - { - auto q_dram_window = make_tile_window( - q_dram_block_window.get_bottom_tensor_view(), - q_dram_block_window.get_window_lengths(), - q_dram_block_window.get_window_origin(), - Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for - // load - - auto do_dram_window = make_tile_window( - do_dram_block_window.get_bottom_tensor_view(), - do_dram_block_window.get_window_lengths(), - do_dram_block_window.get_window_origin(), - Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile - // window for load - - // STAGE 1, Q@K Gemm0 - auto st_acc = SPTBlockTileType{}; - - auto q_block_tile = load_tile(q_dram_window); - { - move_tile_window(q_dram_window, {0, kK0}); - - clear_tile(st_acc); // Initialize S^T - - store_tile(q_lds_window, q_block_tile); // LDS write 0 - q_block_tile = load_tile(q_dram_window); // global read 1 - } - - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - __builtin_amdgcn_sched_barrier( - 0); // prevent from messing up the order of global loads - } - const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - __builtin_amdgcn_sched_barrier( - 0); // prevent from messing up the order of global loads - } - - if constexpr(k0_loops > 2) - { - static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { - block_sync_lds(); - gemm_0(st_acc, - q_lds_window, - get_slice_tile(k_lds_window, - sequence<0, i_k0 * kK0>{}, - sequence{})); - block_sync_lds(); - move_tile_window(q_dram_window, {0, kK0}); - - store_tile(q_lds_window, - 
q_block_tile); // LDS write i + 1 - q_block_tile = load_tile(q_dram_window); // global read i + 2 - }); - } - - const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile - { // tail - block_sync_lds(); - gemm_0(st_acc, - q_lds_window, - get_slice_tile(k_lds_window, - sequence<0, (k0_loops - 2) * kK0>{}, - sequence{})); - block_sync_lds(); - - store_tile(q_lds_window, q_block_tile); - block_sync_lds(); - - gemm_0(st_acc, - q_lds_window, - get_slice_tile(k_lds_window, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{})); - } - - // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - block_sync_lds(); - auto bias_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBiasTileDistribution()); - shuffle_tile(bias_shuffle_tmp, bias_tile); - store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); - block_sync_lds(); - auto biast_tile = load_tile(biast_lds_window); - tile_elementwise_inout( - [&](auto& x, const auto& y) { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - x = raw_scale * x + type_convert(y); -#else - x = scale * x + log2e_v * type_convert(y); -#endif - }, - st_acc, - biast_tile); - move_tile_window(bias_dram_window, {kM0, 0}); - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - const auto q_origin = q_dram_block_window.get_window_origin(); - constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); - sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - st_acc(i_j_idx) *= raw_scale; -#else - st_acc(i_j_idx) *= 
scale; -#endif - position_encoding.update(st_acc(i_j_idx), row, col); - }); - }); - } - else - { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); -#endif - } - - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) - { - const auto q_origin = q_dram_block_window.get_window_origin(); - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}), - number{}, - number{}); - if(need_perpixel_check) - { - set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); - }); - } - } - - const auto lse = load_tile(lse_dram_window); - - static const auto get_validated_lse = [](LSEDataType raw_lse) { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_lse == -numeric::infinity() - ? 
type_convert(0.f) - : raw_lse; - } - else - { - return raw_lse; - } - }; - - auto pt = SPTBlockTileType{}; - constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); - sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); -#endif - sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); - } - else - { - pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); - } -#else - pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); -#endif - }); - }); - - auto dot_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledOGradTRegBlockDescriptor()); - block_sync_lds(); - { - shuffle_tile(dot_shuffle_tmp, dot_prefetch); - store_tile(dot_lds_window, - dot_shuffle_tmp); // store the prefetch - } - move_tile_window(dot_dram_window, {0, kK1}); - - if constexpr(kHasDropout) - { - dropout.Run( - seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); - } - - // STAGE 3, P^T@OGrad^T Gemm1 - const auto pt_gemm = [&]() { - if constexpr(kHasDropout) - { - return tile_elementwise_in( - [](const auto& x) { return type_convert(x > 0.f ? 
x : 0.f); }, - pt); - } - else - { - return cast_tile(pt); - } - }(); - - if constexpr(k1_loops > 1) - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto dot = load_tile(dot_dram_window); // load next OGrad^T - block_sync_lds(); - gemm_1(dv_acc, - get_slice_tile(pt_gemm, - sequence{}, - sequence<(i_k1 + 1) * kK1, kN0>{}), - dot_lds_window); - block_sync_lds(); - shuffle_tile(dot_shuffle_tmp, dot); - store_tile(dot_lds_window, - dot_shuffle_tmp); // store the prefetch - - move_tile_window(dot_dram_window, {0, kK1}); - }); - } - auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile - // tail - { - block_sync_lds(); - gemm_1(dv_acc, - get_slice_tile( - pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence{}), - dot_lds_window); - block_sync_lds(); - } - - // STAGE 4, OGrad@V Gemm2 - auto dpt_acc = SPGradTBlockTileType{}; - - { - move_tile_window(do_dram_window, {0, kK2}); - - clear_tile(dpt_acc); // Initialize PGrad^T - - store_tile(do_lds_window, do_block_tile); // LDS write 0 - do_block_tile = load_tile(do_dram_window); // global read 1 - } - - if constexpr(k2_loops > 2) - { - static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) { - block_sync_lds(); - gemm_2(dpt_acc, - do_lds_window, - get_slice_tile( - v, sequence<0, i_k2 * kK2>{}, sequence{})); - block_sync_lds(); - move_tile_window(do_dram_window, {0, kK2}); - - store_tile(do_lds_window, - do_block_tile); // LDS write i + 1 - do_block_tile = load_tile(do_dram_window); // global read i + 2 - }); - } - - const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile - { // tail - block_sync_lds(); - gemm_2(dpt_acc, - do_lds_window, - get_slice_tile(v, - sequence<0, (k2_loops - 2) * kK2>{}, - sequence{})); - block_sync_lds(); - - store_tile(do_lds_window, do_block_tile); - block_sync_lds(); - - gemm_2(dpt_acc, - do_lds_window, - get_slice_tile(v, - sequence<0, (k2_loops - 1) * kK2>{}, - sequence{})); - } - - // STAGE 5, P^T(PGrad^T - D) - const auto d = 
load_tile(d_dram_window); - - auto dst = SPGradTBlockTileType{}; - constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); - sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - bool undrop_flag = pt[i_j_idx] >= 0; - dst(i_j_idx) = - pt[i_j_idx] * - (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); - }); - }); - - if constexpr(kHasBiasGrad) - { - const auto dbiast = [&]() { - if constexpr(kHasDropout) - { - return tile_elementwise_in( - [&rp_undrop](const auto& x) { - return type_convert(x * rp_undrop); - }, - dst); - } - else - { - return cast_tile(dst); - } - }(); - store_tile(biast_lds_shuffle_window, dbiast); - block_sync_lds(); - auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); - auto dbiast_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeBiasTileDistribution()); - shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); - store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); - move_tile_window(dbias_dram_block_window, {kM0, 0}); - } - - // STAGE 6, SGrad^T@Q^T Gemm3 - auto qt_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledQTRegBlockDescriptor()); - block_sync_lds(); - { - shuffle_tile(qt_shuffle_tmp, qt_prefetch); - store_tile(qt_lds_window, - qt_shuffle_tmp); // store the prefetch - } - move_tile_window(qt_dram_window, {0, kK3}); - - const auto dst_gemm = cast_tile(dst); - - if constexpr(k3_loops > 1) - { - static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) { - const auto qt = load_tile(qt_dram_window); // load next Q^T - block_sync_lds(); - gemm_3(dk_acc, - get_slice_tile(dst_gemm, - sequence{}, - sequence<(i_k3 + 1) * kK3, kN0>{}), - qt_lds_window); - block_sync_lds(); - shuffle_tile(qt_shuffle_tmp, qt); - store_tile(qt_lds_window, - qt_shuffle_tmp); // store the prefetch - - move_tile_window(qt_dram_window, {0, 
kK3}); - }); - } - // tail - { - block_sync_lds(); - gemm_3(dk_acc, - get_slice_tile( - dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence{}), - qt_lds_window); - block_sync_lds(); - } - - // STAGE 7, SGrad@K^T Gemm4 - store_tile(ds_lds_window, dst_gemm); - - auto dq_acc = QGradBlockTileType{}; - clear_tile(dq_acc); // Initialize QGrad - - block_sync_lds(); - - static_for<0, k4_loops, 1>{}([&](auto i_k4) { - gemm_4(dq_acc, - get_slice_tile(ds_lds_window, - sequence<0, i_k4 * kK4>{}, - sequence{}), - get_slice_tile(kt_lds_window, - sequence<0, i_k4 * kK4>{}, - sequence{})); - }); - - // QGrad Scale - if constexpr(kHasDropout) - { - tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, - dq_acc); - } - else - { - tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); - } - const auto dq = cast_tile(dq_acc); - update_tile(dq_dram_block_window, dq); - - // move tile windows - move_tile_window(q_dram_block_window, {kM0, 0}); - move_tile_window(dq_dram_block_window, {kM0, 0}); - move_tile_window(do_dram_block_window, {kM0, 0}); - move_tile_window(lse_dram_window, {kM0}); - move_tile_window(d_dram_window, {kM0}); - } while(++i_total_loops < num_total_loop); - - // KGrad Scale - if constexpr(kHasDropout) - { - tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, - dk_acc); - } - else - { - tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); - } - // VGrad Scale - if constexpr(kHasDropout) - { - tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); - } - - return ck_tile::make_tuple(dk_acc, dv_acc); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp deleted file mode 100644 index a05fbf252f..0000000000 --- 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" - -namespace ck_tile { - -// This pipeline is v located in regs, k & k^t located in lds. -using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy = - BlockFmhaBwdPipelineDefaultPolicy; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp deleted file mode 100644 index dec421c1e6..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp +++ /dev/null @@ -1,821 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/block/block_dropout.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" -#include "ck_tile/ops/reduce/block/block_reduce.hpp" - -namespace ck_tile { - -template -struct BlockFmhaBwdDQDKDVPipelineKSVR -{ - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using GemmDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using DDataType = remove_cvref_t; - using RandValOutputDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using OGradDataType = remove_cvref_t; - using QGradDataType = remove_cvref_t; - using KGradDataType = remove_cvref_t; - using VGradDataType = remove_cvref_t; - using BiasGradDataType = remove_cvref_t; - using FmhaMask = 
remove_cvref_t; - - using BlockFmhaShape = remove_cvref_t; - - static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; - static constexpr index_t kBlockSize = Problem::kBlockSize; - - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK2 = BlockFmhaShape::kK2; - static constexpr index_t kK3 = BlockFmhaShape::kK3; - static constexpr index_t kK4 = BlockFmhaShape::kK4; - static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; - static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; - - static constexpr bool kQLoadOnce = false; - static constexpr bool kQTLoadOnce = false; - static constexpr bool kKLoadOnce = true; - static constexpr bool kKTLoadOnce = false; - static constexpr bool kVLoadOnce = true; - static constexpr bool kOGradLoadOnce = false; - static constexpr bool kOGradTLoadOnce = false; - - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; - static constexpr bool kHasDropout = Problem::kHasDropout; - - // last dimension vector length used to create tensor view(and decide buffer_load vector length) - // ... together with tensor distribution. tensor dist should able to overwrite this - static constexpr index_t kAlignmentQ = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); - static constexpr index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); - static constexpr index_t kAlignmentV = - kPadHeadDimV ? 
1 : Policy::template GetAlignmentV(); - static constexpr index_t kAlignmentO = - kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); - static constexpr index_t kAlignmentOGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); - static constexpr index_t kAlignmentQGrad = - kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); - static constexpr index_t kAlignmentKGrad = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); - static constexpr index_t kAlignmentVGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); - static constexpr index_t kAlignmentBias = - kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); - - static constexpr const char* name = "ks_vr"; - - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, - const QTDramBlockWindowTmp& qt_dram_block_window_tmp, - const KDramBlockWindowTmp& k_dram_block_window_tmp, - const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/, - const VDramBlockWindowTmp& v_dram_block_window_tmp, - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, - const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, - const OGradDramBlockWindowTmp& do_dram_block_window_tmp, - const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp, - const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, - const DDramBlockWindowTmp& d_dram_block_window_tmp, - const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, - const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, - FmhaMask mask, - PositionEncoding position_encoding, - float raw_scale, -#if CK_TILE_FMHA_FWD_FAST_EXP2 - float scale, -#endif - float rp_undrop, - float scale_rp_undrop, - void* smem_ptr, - BlockDropout& dropout) const - { - static_assert( - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - 
std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kVHeaddim == - OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); - - // Q tile in LDS - QDataType* q_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK())); - auto q_lds = make_tensor_view( - q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_window = - make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); - - // QT tile in LDS - QDataType* qt_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK())); - auto qt_lds = make_tensor_view( - qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor()); - auto qt_lds_window = - make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); - - // K tile in LDS - auto k_lds = make_tensor_view( - reinterpret_cast(smem_ptr), - Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); - - // KT tile in LDS - auto kt_lds = make_tensor_view( - reinterpret_cast(smem_ptr), - 
Policy::template MakeKLdsBlockDescriptorAsKT()); - auto kt_lds_window = - make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); - - // OGrad tile in LDS - OGradDataType* do_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK())); - auto do_lds = make_tensor_view( - do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); - auto do_lds_window = - make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); - - // OGradT tile in LDS - OGradDataType* dot_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK())); - auto dot_lds = make_tensor_view( - dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor()); - auto dot_lds_window = - make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); - - // SGrad tile in LDS - GemmDataType* ds_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK())); - auto ds_lds = make_tensor_view( - ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); - auto ds_lds_window = - make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); - - // BiasT/BiasGradT tile in LDS, use the same size and layout - BiasDataType* biast_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK())); - auto biast_lds = make_tensor_view( - biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); - auto biast_lds_shuffle_window = - make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); - auto dbiast_lds_shuffle_window = - make_tile_window(biast_lds, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeShuffledBiasTileDistribution()); - - static_assert(std::is_same_v, - "BiasDataType and BiasGradDataType should be the same!"); - - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); - constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); - 
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); - constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); - - auto v_dram_window = make_tile_window( - v_dram_block_window_tmp.get_bottom_tensor_view(), - v_dram_block_window_tmp.get_window_lengths(), - v_dram_block_window_tmp.get_window_origin(), - Policy::template MakeVInRegDramTileDistribution()); - - auto v = load_tile(v_dram_window); // persistent V register tile - - using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); - using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); - using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); - - // init VGrad & KGrad - auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; - auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; - - clear_tile(dv_acc); - clear_tile(dk_acc); - - auto k_dram_window = make_tile_window( - k_dram_block_window_tmp.get_bottom_tensor_view(), - k_dram_block_window_tmp.get_window_lengths(), - k_dram_block_window_tmp.get_window_origin(), - Policy::template MakeKDramTileDistribution()); // K DRAM tile window for - // load - - __builtin_amdgcn_sched_barrier(0); - const auto k_origin = k_dram_window.get_window_origin(); - const auto [seqlen_q_start, seqlen_q_end] = - mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); - - const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); - - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking) - { - if(num_total_loop <= 0) - { - // Note: here dk_acc&dv_acc are all cleard, return it - // Note: v loaded but no fence, ignore it. 
- return ck_tile::make_tuple(dk_acc, dv_acc); - } - } - - auto k_block_tile = load_tile(k_dram_window); - - store_tile(k_lds_window, k_block_tile); // // persistent K in LDS - - auto q_dram_block_window = - make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - q_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}); - - auto qt_dram_block_window = - make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(), - qt_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_q_start}); - - auto do_dram_block_window = - make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), - do_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}); - - auto dot_dram_block_window = - make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(), - dot_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_q_start}); - - auto dq_dram_block_window = - make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), - dq_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}); - - auto lse_dram_block_window = - make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), - lse_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start}); - - auto d_dram_block_window = - make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), - d_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start}); - - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - auto bias_dram_block_window = - make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N - - const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); - auto dbias_dram_block_window = - make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), - dbias_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N - - auto 
qt_dram_window = - make_tile_window(qt_dram_block_window.get_bottom_tensor_view(), - qt_dram_block_window.get_window_lengths(), - qt_dram_block_window.get_window_origin(), - Policy::template MakeQTDramTileDistribution()); - - auto dot_dram_window = - make_tile_window(dot_dram_block_window.get_bottom_tensor_view(), - dot_dram_block_window.get_window_lengths(), - dot_dram_block_window.get_window_origin(), - Policy::template MakeOGradTDramTileDistribution()); - - auto lse_dram_window = make_tile_window( - lse_dram_block_window.get_bottom_tensor_view(), - lse_dram_block_window.get_window_lengths(), - lse_dram_block_window.get_window_origin(), - Policy::template MakeLSEDDramTileDistribution()); - - auto d_dram_window = make_tile_window( - d_dram_block_window.get_bottom_tensor_view(), - d_dram_block_window.get_window_lengths(), - d_dram_block_window.get_window_origin(), - Policy::template MakeLSEDDramTileDistribution()); - - auto bias_dram_window = - make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), - bias_dram_block_window.get_window_lengths(), - bias_dram_block_window.get_window_origin(), - Policy::template MakeBiasTileDistribution()); - - auto biast_lds_window = - make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), - biast_lds_shuffle_window.get_window_lengths(), - biast_lds_shuffle_window.get_window_origin(), - Policy::template MakeBiasTTileDistribution()); - - auto randval_dram_window = dropout.MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_q_start); - - index_t i_total_loops = 0; - constexpr index_t k0_loops = kQKHeaddim / kK0; - constexpr index_t k1_loops = kM0 / kK1; - constexpr index_t k2_loops = kVHeaddim / kK2; - constexpr index_t k3_loops = kM0 / kK3; - constexpr index_t k4_loops = kN0 / kK4; - do - { - auto q_dram_window = make_tile_window( - q_dram_block_window.get_bottom_tensor_view(), - q_dram_block_window.get_window_lengths(), - q_dram_block_window.get_window_origin(), - Policy::template 
MakeQDramTileDistribution()); // Q DRAM tile window for - // load - - auto do_dram_window = make_tile_window( - do_dram_block_window.get_bottom_tensor_view(), - do_dram_block_window.get_window_lengths(), - do_dram_block_window.get_window_origin(), - Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile - // window for load - - // STAGE 1, Q@K Gemm0 - auto st_acc = SPTBlockTileType{}; - - auto q_block_tile = load_tile(q_dram_window); - { - move_tile_window(q_dram_window, {0, kK0}); - - clear_tile(st_acc); // Initialize S^T - - store_tile(q_lds_window, q_block_tile); // LDS write 0 - q_block_tile = load_tile(q_dram_window); // global read 1 - } - - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - __builtin_amdgcn_sched_barrier( - 0); // prevent from messing up the order of global loads - } - const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - __builtin_amdgcn_sched_barrier( - 0); // prevent from messing up the order of global loads - } - - if constexpr(k0_loops > 2) - { - static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { - block_sync_lds(); - gemm_0(st_acc, - q_lds_window, - get_slice_tile(k_lds_window, - sequence<0, i_k0 * kK0>{}, - sequence{})); - block_sync_lds(); - move_tile_window(q_dram_window, {0, kK0}); - - store_tile(q_lds_window, - q_block_tile); // LDS write i + 1 - q_block_tile = load_tile(q_dram_window); // global read i + 2 - }); - } - - const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile - { // tail - block_sync_lds(); - gemm_0(st_acc, - q_lds_window, - get_slice_tile(k_lds_window, - sequence<0, (k0_loops - 2) * kK0>{}, - sequence{})); - block_sync_lds(); - - store_tile(q_lds_window, q_block_tile); - block_sync_lds(); - - gemm_0(st_acc, - q_lds_window, - get_slice_tile(k_lds_window, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{})); - } - - // STAGE 2, Scale, Add bias, Mask, Softmax, 
Dropout - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - block_sync_lds(); - auto bias_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBiasTileDistribution()); - shuffle_tile(bias_shuffle_tmp, bias_tile); - store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); - block_sync_lds(); - auto biast_tile = load_tile(biast_lds_window); - tile_elementwise_inout( - [&](auto& x, const auto& y) { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - x = raw_scale * x + type_convert(y); -#else - x = scale * x + log2e_v * type_convert(y); -#endif - }, - st_acc, - biast_tile); - move_tile_window(bias_dram_window, {kM0, 0}); - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - const auto q_origin = q_dram_block_window.get_window_origin(); - constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); - sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - st_acc(i_j_idx) *= raw_scale; -#else - st_acc(i_j_idx) *= scale; -#endif - position_encoding.update(st_acc(i_j_idx), row, col); - }); - }); - } - else - { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); -#endif - } - - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) - { - const auto q_origin = q_dram_block_window.get_window_origin(); - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}), - number{}, - number{}); - if(need_perpixel_check) - { - set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) 
+ tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); - }); - } - } - - const auto lse = load_tile(lse_dram_window); - - static const auto get_validated_lse = [](LSEDataType raw_lse) { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_lse == -numeric::infinity() - ? type_convert(0.f) - : raw_lse; - } - else - { - return raw_lse; - } - }; - - auto pt = SPTBlockTileType{}; - constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); - sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); -#endif - sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); - } - else - { - pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); - } -#else - pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); -#endif - }); - }); - - auto dot_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledOGradTRegBlockDescriptor()); - block_sync_lds(); - { - shuffle_tile(dot_shuffle_tmp, dot_prefetch); - store_tile(dot_lds_window, - dot_shuffle_tmp); // store the prefetch - } - move_tile_window(dot_dram_window, {0, kK1}); - - if constexpr(kHasDropout) - { - dropout.Run( - seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); - } - - // STAGE 3, P^T@OGrad^T Gemm1 - const auto pt_gemm = [&]() { - if constexpr(kHasDropout) - { - return tile_elementwise_in( - [](const auto& x) { return type_convert(x > 0.f ? 
x : 0.f); }, - pt); - } - else - { - return cast_tile(pt); - } - }(); - - if constexpr(k1_loops > 1) - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto dot = load_tile(dot_dram_window); // load next OGrad^T - block_sync_lds(); - gemm_1(dv_acc, - get_slice_tile(pt_gemm, - sequence{}, - sequence<(i_k1 + 1) * kK1, kN0>{}), - dot_lds_window); - block_sync_lds(); - shuffle_tile(dot_shuffle_tmp, dot); - store_tile(dot_lds_window, - dot_shuffle_tmp); // store the prefetch - - move_tile_window(dot_dram_window, {0, kK1}); - }); - } - auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile - // tail - { - block_sync_lds(); - gemm_1(dv_acc, - get_slice_tile( - pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence{}), - dot_lds_window); - block_sync_lds(); - } - - // STAGE 4, OGrad@V Gemm2 - auto dpt_acc = SPGradTBlockTileType{}; - - { - move_tile_window(do_dram_window, {0, kK2}); - - clear_tile(dpt_acc); // Initialize PGrad^T - - store_tile(do_lds_window, do_block_tile); // LDS write 0 - do_block_tile = load_tile(do_dram_window); // global read 1 - } - - if constexpr(k2_loops > 2) - { - static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) { - block_sync_lds(); - gemm_2(dpt_acc, - do_lds_window, - get_slice_tile( - v, sequence<0, i_k2 * kK2>{}, sequence{})); - block_sync_lds(); - move_tile_window(do_dram_window, {0, kK2}); - - store_tile(do_lds_window, - do_block_tile); // LDS write i + 1 - do_block_tile = load_tile(do_dram_window); // global read i + 2 - }); - } - - const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile - { // tail - block_sync_lds(); - gemm_2(dpt_acc, - do_lds_window, - get_slice_tile(v, - sequence<0, (k2_loops - 2) * kK2>{}, - sequence{})); - block_sync_lds(); - - store_tile(do_lds_window, do_block_tile); - block_sync_lds(); - - gemm_2(dpt_acc, - do_lds_window, - get_slice_tile(v, - sequence<0, (k2_loops - 1) * kK2>{}, - sequence{})); - } - - // STAGE 5, P^T(PGrad^T - D) - const auto d = 
load_tile(d_dram_window); - - auto dst = SPGradTBlockTileType{}; - constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); - sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - bool undrop_flag = pt[i_j_idx] >= 0; - dst(i_j_idx) = - pt[i_j_idx] * - (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); - }); - }); - - if constexpr(kHasBiasGrad) - { - const auto dbiast = [&]() { - if constexpr(kHasDropout) - { - return tile_elementwise_in( - [&rp_undrop](const auto& x) { - return type_convert(x * rp_undrop); - }, - dst); - } - else - { - return cast_tile(dst); - } - }(); - store_tile(biast_lds_shuffle_window, dbiast); - block_sync_lds(); - auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); - auto dbiast_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeBiasTileDistribution()); - shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); - store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); - move_tile_window(dbias_dram_block_window, {kM0, 0}); - } - - // STAGE 6, SGrad^T@Q^T Gemm3 - auto qt_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledQTRegBlockDescriptor()); - block_sync_lds(); - { - shuffle_tile(qt_shuffle_tmp, qt_prefetch); - store_tile(qt_lds_window, - qt_shuffle_tmp); // store the prefetch - } - move_tile_window(qt_dram_window, {0, kK3}); - - const auto dst_gemm = cast_tile(dst); - - if constexpr(k3_loops > 1) - { - static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) { - const auto qt = load_tile(qt_dram_window); // load next Q^T - block_sync_lds(); - gemm_3(dk_acc, - get_slice_tile(dst_gemm, - sequence{}, - sequence<(i_k3 + 1) * kK3, kN0>{}), - qt_lds_window); - block_sync_lds(); - shuffle_tile(qt_shuffle_tmp, qt); - store_tile(qt_lds_window, - qt_shuffle_tmp); // store the prefetch - - move_tile_window(qt_dram_window, {0, 
kK3}); - }); - } - // tail - { - block_sync_lds(); - gemm_3(dk_acc, - get_slice_tile( - dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence{}), - qt_lds_window); - block_sync_lds(); - } - - // STAGE 7, SGrad@K^T Gemm4 - store_tile(ds_lds_window, dst_gemm); - - auto dq_acc = QGradBlockTileType{}; - clear_tile(dq_acc); // Initialize QGrad - - block_sync_lds(); - - static_for<0, k4_loops, 1>{}([&](auto i_k4) { - gemm_4(dq_acc, - get_slice_tile(ds_lds_window, - sequence<0, i_k4 * kK4>{}, - sequence{}), - get_slice_tile(kt_lds_window, - sequence<0, i_k4 * kK4>{}, - sequence{})); - }); - - // QGrad Scale - if constexpr(kHasDropout) - { - tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, - dq_acc); - } - else - { - tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); - } - const auto dq = cast_tile(dq_acc); - update_tile(dq_dram_block_window, dq); - - // move tile windows - move_tile_window(q_dram_block_window, {kM0, 0}); - move_tile_window(dq_dram_block_window, {kM0, 0}); - move_tile_window(do_dram_block_window, {kM0, 0}); - move_tile_window(lse_dram_window, {kM0}); - move_tile_window(d_dram_window, {kM0}); - } while(++i_total_loops < num_total_loop); - - // KGrad Scale - if constexpr(kHasDropout) - { - tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, - dk_acc); - } - else - { - tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); - } - // VGrad Scale - if constexpr(kHasDropout) - { - tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); - } - - return ck_tile::make_tuple(dk_acc, dv_acc); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp deleted file mode 100644 index cc4e6304d0..0000000000 --- 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" - -namespace ck_tile { - -// This pipeline is v located in regs, k located in lds. -using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy = - BlockFmhaBwdPipelineDefaultPolicy; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp deleted file mode 100644 index 97487bb71e..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp +++ /dev/null @@ -1,692 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/block/block_dropout.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" -#include "ck_tile/ops/reduce/block/block_reduce.hpp" - -namespace ck_tile { - -template -struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS -{ - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using GemmDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using DDataType = remove_cvref_t; - using RandValOutputDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using OGradDataType = remove_cvref_t; - using QGradDataType = remove_cvref_t; - using KGradDataType = remove_cvref_t; - using VGradDataType = remove_cvref_t; - using BiasGradDataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; - - using BlockFmhaShape = remove_cvref_t; - - static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; - static constexpr index_t kBlockSize = Problem::kBlockSize; - - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK2 = BlockFmhaShape::kK2; - static constexpr index_t kK3 = BlockFmhaShape::kK3; - static constexpr index_t kK4 = BlockFmhaShape::kK4; - static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; - static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; - - static constexpr bool kQLoadOnce = true; - static constexpr bool kQTLoadOnce = false; - static constexpr bool kKLoadOnce = true; - static constexpr bool kKTLoadOnce = false; - static constexpr bool kVLoadOnce = true; - static constexpr bool kOGradLoadOnce = true; - static constexpr 
bool kOGradTLoadOnce = false; - - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; - static constexpr bool kHasDropout = Problem::kHasDropout; - - // last dimension vector length used to create tensor view(and decide buffer_load vector length) - // ... together with tensor distribution. tensor dist should able to overwrite this - static constexpr index_t kAlignmentQ = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); - static constexpr index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); - static constexpr index_t kAlignmentV = - kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); - static constexpr index_t kAlignmentO = - kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); - static constexpr index_t kAlignmentOGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); - static constexpr index_t kAlignmentQGrad = - kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); - static constexpr index_t kAlignmentKGrad = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); - static constexpr index_t kAlignmentVGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); - static constexpr index_t kAlignmentBias = - kPadSeqLenK ? 
1 : Policy::template GetTransposedAlignmentBias(); - - static constexpr const char* name = "qs_ks_vr_dos"; - - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, - const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/, - const KDramBlockWindowTmp& k_dram_block_window_tmp, - const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/, - const VDramBlockWindowTmp& v_dram_block_window_tmp, - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, - const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, - const OGradDramBlockWindowTmp& do_dram_block_window_tmp, - const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/, - const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, - const DDramBlockWindowTmp& d_dram_block_window_tmp, - const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, - const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, - FmhaMask mask, - PositionEncoding position_encoding, - float raw_scale, -#if CK_TILE_FMHA_FWD_FAST_EXP2 - float scale, -#endif - float rp_undrop, - float scale_rp_undrop, - void* smem_ptr, - BlockDropout& dropout) const - { - static_assert( - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == 
DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); - - // Q tile in LDS - QDataType* q_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK())); - auto q_lds = make_tensor_view( - q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_window = - make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); - - // QT tile in LDS - auto qt_lds = make_tensor_view( - q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT()); - auto qt_lds_window = - make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); - - // K tile in LDS - auto k_lds = make_tensor_view( - reinterpret_cast(smem_ptr), - Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); - - // KT tile in LDS - auto kt_lds = make_tensor_view( - reinterpret_cast(smem_ptr), - Policy::template MakeKLdsBlockDescriptorAsKT()); - auto kt_lds_window = - make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); - - // OGrad tile in LDS - OGradDataType* do_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeQ())); - auto do_lds = make_tensor_view( - do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); - auto do_lds_window = - make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); - - // OGradT tile in LDS - auto dot_lds = make_tensor_view( - do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT()); - auto dot_lds_window = - make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); - - // SGrad tile in LDS - GemmDataType* ds_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK() + - 
Policy::template GetSmemSizeQ() + - Policy::template GetSmemSizeOGrad())); - auto ds_lds = make_tensor_view( - ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); - auto ds_lds_window = - make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); - - // BiasT/BiasGradT tile in LDS, use the same size and layout - BiasDataType* biast_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeQ() + - Policy::template GetSmemSizeOGrad())); - auto biast_lds = make_tensor_view( - biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); - auto biast_lds_shuffle_window = - make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); - auto dbiast_lds_shuffle_window = - make_tile_window(biast_lds, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeShuffledBiasTileDistribution()); - - static_assert(std::is_same_v, - "BiasDataType and BiasGradDataType should be the same!"); - - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); - constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); - constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); - constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); - - auto v_dram_window = make_tile_window( - v_dram_block_window_tmp.get_bottom_tensor_view(), - v_dram_block_window_tmp.get_window_lengths(), - v_dram_block_window_tmp.get_window_origin(), - Policy::template MakeVInRegDramTileDistribution()); - - auto v = load_tile(v_dram_window); // persistent V register tile - - using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); - using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); - using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); - - // init VGrad & KGrad - auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; - auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; - - clear_tile(dv_acc); - 
clear_tile(dk_acc); - - auto k_dram_window = make_tile_window( - k_dram_block_window_tmp.get_bottom_tensor_view(), - k_dram_block_window_tmp.get_window_lengths(), - k_dram_block_window_tmp.get_window_origin(), - Policy::template MakeKDramTileDistribution()); // K DRAM tile window for - // load - - __builtin_amdgcn_sched_barrier(0); - const auto k_origin = k_dram_window.get_window_origin(); - const auto [seqlen_q_start, seqlen_q_end] = - mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); - - const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); - - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking) - { - if(num_total_loop <= 0) - { - // Note: here dk_acc&dv_acc are all cleard, return it - // Note: v loaded but no fence, ignore it. - return ck_tile::make_tuple(dk_acc, dv_acc); - } - } - - auto k_block_tile = load_tile(k_dram_window); - - store_tile(k_lds_window, k_block_tile); // // persistent K in LDS - - auto q_dram_block_window = - make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - q_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}); - - auto do_dram_block_window = - make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), - do_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}); - - auto dq_dram_block_window = - make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), - dq_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, 0}); - - auto lse_dram_block_window = - make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), - lse_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start}); - - auto d_dram_block_window = - make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), - d_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start}); - - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - auto bias_dram_block_window = - 
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N - - const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); - auto dbias_dram_block_window = - make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), - dbias_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N - - auto lse_dram_window = make_tile_window( - lse_dram_block_window.get_bottom_tensor_view(), - lse_dram_block_window.get_window_lengths(), - lse_dram_block_window.get_window_origin(), - Policy::template MakeLSEDDramTileDistribution()); - - auto d_dram_window = make_tile_window( - d_dram_block_window.get_bottom_tensor_view(), - d_dram_block_window.get_window_lengths(), - d_dram_block_window.get_window_origin(), - Policy::template MakeLSEDDramTileDistribution()); - - auto bias_dram_window = - make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), - bias_dram_block_window.get_window_lengths(), - bias_dram_block_window.get_window_origin(), - Policy::template MakeBiasTileDistribution()); - - auto biast_lds_window = - make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), - biast_lds_shuffle_window.get_window_lengths(), - biast_lds_shuffle_window.get_window_origin(), - Policy::template MakeBiasTTileDistribution()); - - auto randval_dram_window = dropout.MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_q_start); - - index_t i_total_loops = 0; - constexpr index_t k0_loops = kQKHeaddim / kK0; - constexpr index_t k1_loops = kM0 / kK1; - constexpr index_t k2_loops = kVHeaddim / kK2; - constexpr index_t k3_loops = kM0 / kK3; - constexpr index_t k4_loops = kN0 / kK4; - do - { - auto q_dram_window = make_tile_window( - q_dram_block_window.get_bottom_tensor_view(), - q_dram_block_window.get_window_lengths(), - q_dram_block_window.get_window_origin(), - Policy::template 
MakeQDramTileDistribution()); // Q DRAM tile window for - // load - - auto do_dram_window = make_tile_window( - do_dram_block_window.get_bottom_tensor_view(), - do_dram_block_window.get_window_lengths(), - do_dram_block_window.get_window_origin(), - Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile - // window for load - - // STAGE 1, Q@K Gemm0 - auto st_acc = SPTBlockTileType{}; - - auto q_block_tile = load_tile(q_dram_window); - clear_tile(st_acc); // Initialize S^T - store_tile(q_lds_window, q_block_tile); // LDS write - - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - __builtin_amdgcn_sched_barrier( - 0); // prevent from messing up the order of global loads - } - const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - __builtin_amdgcn_sched_barrier( - 0); // prevent from messing up the order of global loads - } - - if constexpr(k0_loops > 1) - { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - block_sync_lds(); - gemm_0(st_acc, - get_slice_tile(q_lds_window, - sequence<0, i_k0 * kK0>{}, - sequence{}), - get_slice_tile(k_lds_window, - sequence<0, i_k0 * kK0>{}, - sequence{})); - block_sync_lds(); - }); - } - - auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile - { // tail - block_sync_lds(); - gemm_0(st_acc, - get_slice_tile(q_lds_window, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - get_slice_tile(k_lds_window, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{})); - block_sync_lds(); - } - - // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - block_sync_lds(); - auto bias_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBiasTileDistribution()); - shuffle_tile(bias_shuffle_tmp, bias_tile); - store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); - block_sync_lds(); - auto biast_tile = 
load_tile(biast_lds_window); - tile_elementwise_inout( - [&](auto& x, const auto& y) { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - x = raw_scale * x + type_convert(y); -#else - x = scale * x + log2e_v * type_convert(y); -#endif - }, - st_acc, - biast_tile); - move_tile_window(bias_dram_window, {kM0, 0}); - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - const auto q_origin = q_dram_block_window.get_window_origin(); - constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); - sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - st_acc(i_j_idx) *= raw_scale; -#else - st_acc(i_j_idx) *= scale; -#endif - position_encoding.update(st_acc(i_j_idx), row, col); - }); - }); - } - else - { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); -#endif - } - - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) - { - const auto q_origin = q_dram_block_window.get_window_origin(); - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}), - number{}, - number{}); - if(need_perpixel_check) - { - set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); - }); - } - } - - const auto lse = load_tile(lse_dram_window); - - static const auto get_validated_lse = [](LSEDataType raw_lse) { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - 
FmhaMask::IsMasking) - { - return raw_lse == -numeric::infinity() - ? type_convert(0.f) - : raw_lse; - } - else - { - return raw_lse; - } - }; - - auto pt = SPTBlockTileType{}; - constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); - sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); -#endif - sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); - } - else - { - pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); - } -#else - pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); -#endif - }); - }); - - if constexpr(kHasDropout) - { - dropout.Run( - seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); - } - - // STAGE 3, P^T@OGrad^T Gemm1 - block_sync_lds(); - store_tile(do_lds_window, do_block_tile); // store the prefetch - - const auto pt_gemm = [&]() { - if constexpr(kHasDropout) - { - return tile_elementwise_in( - [](const auto& x) { return type_convert(x > 0.f ? 
x : 0.f); }, - pt); - } - else - { - return cast_tile(pt); - } - }(); - - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - block_sync_lds(); - gemm_1(dv_acc, - get_slice_tile( - pt_gemm, sequence{}, sequence<(i_k1 + 1) * kK1, kN0>{}), - get_slice_tile(dot_lds_window, - sequence<0, i_k1 * kK1>{}, - sequence{})); - block_sync_lds(); - }); - - // STAGE 4, OGrad@V Gemm2 - auto dpt_acc = SPGradTBlockTileType{}; - clear_tile(dpt_acc); // Initialize PGrad^T - - static_for<0, k2_loops, 1>{}([&](auto i_k2) { - block_sync_lds(); - gemm_2(dpt_acc, - get_slice_tile(do_lds_window, - sequence<0, i_k2 * kK2>{}, - sequence{}), - get_slice_tile( - v, sequence<0, i_k2 * kK2>{}, sequence{})); - block_sync_lds(); - }); - - // STAGE 5, P^T(PGrad^T - D) - const auto d = load_tile(d_dram_window); - - auto dst = SPGradTBlockTileType{}; - constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); - sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - bool undrop_flag = pt[i_j_idx] >= 0; - dst(i_j_idx) = - pt[i_j_idx] * - (!kHasDropout || undrop_flag ? 
(dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); - }); - }); - - if constexpr(kHasBiasGrad) - { - const auto dbiast = [&]() { - if constexpr(kHasDropout) - { - return tile_elementwise_in( - [&rp_undrop](const auto& x) { - return type_convert(x * rp_undrop); - }, - dst); - } - else - { - return cast_tile(dst); - } - }(); - store_tile(biast_lds_shuffle_window, dbiast); - block_sync_lds(); - auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); - auto dbiast_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeBiasTileDistribution()); - shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); - store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); - move_tile_window(dbias_dram_block_window, {kM0, 0}); - } - - // STAGE 6, SGrad^T@Q^T Gemm3 - block_sync_lds(); - const auto dst_gemm = cast_tile(dst); - - static_for<0, k3_loops, 1>{}([&](auto i_k3) { - block_sync_lds(); - gemm_3(dk_acc, - get_slice_tile( - dst_gemm, sequence{}, sequence<(i_k3 + 1) * kK3, kN0>{}), - get_slice_tile(qt_lds_window, - sequence<0, i_k3 * kK3>{}, - sequence{})); - block_sync_lds(); - }); - - // STAGE 7, SGrad@K^T Gemm4 - store_tile(ds_lds_window, dst_gemm); - - auto dq_acc = QGradBlockTileType{}; - clear_tile(dq_acc); // Initialize QGrad - - block_sync_lds(); - - static_for<0, k4_loops, 1>{}([&](auto i_k4) { - gemm_4(dq_acc, - get_slice_tile(ds_lds_window, - sequence<0, i_k4 * kK4>{}, - sequence{}), - get_slice_tile(kt_lds_window, - sequence<0, i_k4 * kK4>{}, - sequence{})); - }); - - // QGrad Scale - if constexpr(kHasDropout) - { - tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, - dq_acc); - } - else - { - tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); - } - const auto dq = cast_tile(dq_acc); - update_tile(dq_dram_block_window, dq); - - // move tile windows - move_tile_window(q_dram_block_window, {kM0, 0}); - move_tile_window(dq_dram_block_window, {kM0, 0}); - move_tile_window(do_dram_block_window, {kM0, 0}); - 
move_tile_window(lse_dram_window, {kM0}); - move_tile_window(d_dram_window, {kM0}); - } while(++i_total_loops < num_total_loop); - - // KGrad Scale - if constexpr(kHasDropout) - { - tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, - dk_acc); - } - else - { - tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); - } - // VGrad Scale - if constexpr(kHasDropout) - { - tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); - } - - return ck_tile::make_tuple(dk_acc, dv_acc); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp deleted file mode 100644 index ac81990e07..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" - -namespace ck_tile { - -// This pipeline is v located in regs, q & k & do located in lds. 
-using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy = - BlockFmhaBwdPipelineDefaultPolicy; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index d867772a1f..4143c34ff8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -11,6 +11,8 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" @@ -18,60 +20,215 @@ namespace ck_tile { -template struct BlockFmhaBwdPipelineDefaultPolicy { - static constexpr bool QLoadOnce = - QLoadOnce_; // if q load whole block length (qkhdim) to LDS at once - static constexpr bool QTLoadOnce = - QTLoadOnce_; // if q^t load whole block length (qkhdim) to LDS at once - static constexpr bool KLoadOnce = - KLoadOnce_; // if k load whole block length (qkhdim) to LDS at once - static constexpr bool KTLoadOnce = - KTLoadOnce_; // if k^t load whole block length (qkhdim) to LDS at once - static constexpr bool VLoadOnce = - VLoadOnce_; // if v load whole block length (vhdim) to Vgprs at once - static constexpr bool OGradLoadOnce = - OGradLoadOnce_; // if do load whole block length (vhdim) to LDS at once - static constexpr bool OGradTLoadOnce = - OGradTLoadOnce_; // if do^t load whole block length (vhdim) to LDS at once + template + 
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = WarpGemmMfmaDispatcher< + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::AccDataType, + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + false, + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = WarpGemmMfmaDispatcher< + typename Problem::OGradDataType, + typename Problem::VDataType, + typename Problem::AccDataType, + Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}), + false, + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? 
false : true>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}), + true>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), + false>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } // these are for global load template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() { - using QDataType = remove_cvref_t; - return 16 / sizeof(QDataType); + using QDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType); + constexpr index_t kMinVecLoad = 4 / sizeof(QDataType); + + constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) + ? 
kMaxVecLoad + : (total_pixels / kMinVecLoad); + + return kVecLoad; } template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() { - using KDataType = remove_cvref_t; - return 16 / sizeof(KDataType); + using KDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType); + constexpr index_t kMinVecLoad = 4 / sizeof(KDataType); + + constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) + ? kMaxVecLoad + : (total_pixels / kMinVecLoad); + + return kVecLoad; } template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { - if constexpr(VLoadOnce) - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; - } - else - { - using VDataType = remove_cvref_t; - return 16 / sizeof(VDataType); - } + using VDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType); + constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; + + return total_pixels > kMaxVecLoad ? 
kMaxVecLoad : total_pixels; } template @@ -84,20 +241,39 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad() { - using OGradDataType = remove_cvref_t; - return 16 / sizeof(OGradDataType); + using OGradDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType); + constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType); + + constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) + ? kMaxVecLoad + : (total_pixels / kMinVecLoad); + + return kVecLoad; } template - CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQGrad() + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias() { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - using CWarpDstr = typename WG::CWarpDstr; - constexpr auto vec = - CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); - return vec; + using BiasDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType); + constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType); + + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + + constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) + ? 
kMaxVecLoad + : (total_pixels / kMinVecLoad); + + return kVecLoad; } template @@ -128,60 +304,35 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ() { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPerBlock = [&]() { - if constexpr(QTLoadOnce) - return Problem::BlockFmhaShape::kM0; - else - return Problem::BlockFmhaShape::kK3; - }(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - // TODO: not correct! - if constexpr(total_pixels > 4) - return 4; - else - return 2; + return total_pixels / GetAlignmentQ(); } template CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentK() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPerBlock = [&]() { - if constexpr(KTLoadOnce) - return Problem::BlockFmhaShape::kN0; - else - return Problem::BlockFmhaShape::kK4; - }(); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - // TODO: not correct! 
- if constexpr(total_pixels > 4) - return 4; - else - return 2; + return total_pixels / GetAlignmentK(); } template CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad() { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; - constexpr index_t kKPerBlock = [&]() { - if constexpr(OGradTLoadOnce) - return Problem::BlockFmhaShape::kM0; - else - return Problem::BlockFmhaShape::kK1; - }(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - // TODO: not correct! - if constexpr(total_pixels > 4) - return 4; - else - return 2; + return total_pixels / GetAlignmentOGrad(); } template @@ -193,554 +344,56 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; - // TODO: not correct! - if constexpr(total_pixels > 32) - return 8; - else - return 4; - } - - // these are for lds - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() - { - // TODO: this is for 3d layout - using QDataType = remove_cvref_t; - return 16 / sizeof(QDataType); + return total_pixels / GetAlignmentBias(); } template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGradAcc() { - // TODO: this is for 3d layout - using KDataType = remove_cvref_t; - return 16 / sizeof(KDataType); + using AccDataType = remove_cvref_t; + return 16 / sizeof(AccDataType); } template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGrad() { - // TODO: this is for 3d layout - using VDataType = remove_cvref_t; - return 16 / sizeof(VDataType); + return GetAlignmentPostQGradAcc(); } template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias() + CK_TILE_HOST_DEVICE static 
constexpr auto MakeKDramTileDistribution() { - // TODO: this is for 3d layout - using BiasDataType = remove_cvref_t; - return 16 / sizeof(BiasDataType); - } + constexpr index_t kBlockSize = Problem::kBlockSize; - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad() - { - // TODO: this is for 3d layout - using OGradDataType = remove_cvref_t; - return 16 / sizeof(OGradDataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad() - { - // TODO: this is for 3d layout - using GemmDataType = remove_cvref_t; - return 16 / sizeof(GemmDataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeVInRegDramTileDistribution() - { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr auto v_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - v_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); - - constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); - - return v_block_dstr; - } - - // 3d + padding - template - CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() - { - constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number<(MNPerBlock + 1) * KPack>{}, number{}, number<1>{}), - number<8>{}, - number<1>{}); - - constexpr auto 
x_lds_block_desc = transform_tensor_descriptor( - x_lds_block_desc_0, - make_tuple(make_pass_through_transform(MNPerBlock), - make_merge_transform(make_tuple(KPerBlock / KPack, KPack))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return x_lds_block_desc; - } - - // 3d + padding - template - CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptorAsXT() - { - constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number<(MNPerBlock + 1) * KPack>{}, number{}, number<1>{}), - number<8>{}, - number<1>{}); - - constexpr auto xt_lds_block_desc = transform_tensor_descriptor( - x_lds_block_desc_0, - make_tuple(make_pass_through_transform(MNPerBlock), - make_merge_transform(make_tuple(KPerBlock / KPack, KPack))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return xt_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor() - { - static_assert(PixelsPerRow % KPack == 0); - constexpr index_t NPerRow = PixelsPerRow / KPack; - static_assert(MNPerBlock % NPerRow == 0); - static_assert(KPerBlock % KPack == 0); - - constexpr auto xt_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}, - number{}), - make_tuple(number<(MNPerBlock / NPerRow) * (PixelsPerRow + KPack)>{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); - - constexpr auto xt_lds_block_desc = transform_tensor_descriptor( - xt_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return xt_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() - { - constexpr index_t kMPerBlock = 
Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = [&]() { - if constexpr(QLoadOnce) - return Problem::BlockFmhaShape::kQKHeaddim; - else - return Problem::BlockFmhaShape::kK0; - }(); - constexpr index_t kKPack = GetSmemKPackQ(); - - return MakeXLdsBlockDescriptor(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptorAsQT() - { - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = [&]() { - if constexpr(QLoadOnce) - return Problem::BlockFmhaShape::kQKHeaddim; - else - return Problem::BlockFmhaShape::kK0; - }(); - constexpr index_t kKPack = GetSmemKPackQ(); - - return MakeXLdsBlockDescriptorAsXT(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = [&]() { - if constexpr(KLoadOnce) - return Problem::BlockFmhaShape::kQKHeaddim; - else - return Problem::BlockFmhaShape::kK0; - }(); - constexpr index_t kKPack = GetSmemKPackK(); - - return MakeXLdsBlockDescriptor(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptorAsKT() - { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = [&]() { - if constexpr(KLoadOnce) - return Problem::BlockFmhaShape::kQKHeaddim; - else - return Problem::BlockFmhaShape::kK0; - }(); - constexpr index_t kKPack = GetSmemKPackK(); - - return MakeXLdsBlockDescriptorAsXT(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; - constexpr index_t kKPack = GetSmemKPackV(); - - return MakeXLdsBlockDescriptor(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = 
[&]() { - if constexpr(OGradLoadOnce) - return Problem::BlockFmhaShape::kVHeaddim; - else - return Problem::BlockFmhaShape::kK2; - }(); - constexpr index_t kKPack = GetSmemKPackOGrad(); - - return MakeXLdsBlockDescriptor(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptorAsOGradT() - { - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = [&]() { - if constexpr(OGradLoadOnce) - return Problem::BlockFmhaShape::kVHeaddim; - else - return Problem::BlockFmhaShape::kK2; - }(); - constexpr index_t kKPack = GetSmemKPackOGrad(); - - return MakeXLdsBlockDescriptorAsXT(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPack = GetSmemKPackSGrad(); - - return MakeXLdsBlockDescriptor(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsBlockDescriptor() - { - using QDataType = remove_cvref_t; - constexpr index_t Banks = 32; // TODO: need change based on arch - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QDataType); - constexpr index_t kKPack = GetSmemKPackQ(); - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPerBlock = [&]() { - if constexpr(QTLoadOnce) - return Problem::BlockFmhaShape::kM0; - else - return Problem::BlockFmhaShape::kK3; - }(); - - return MakeXTLdsBlockDescriptor(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsBlockDescriptor() - { - using KDataType = remove_cvref_t; - constexpr index_t Banks = 32; // TODO: need change based on arch - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType); - constexpr index_t kKPack = GetSmemKPackK(); - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPerBlock = [&]() { - if constexpr(KTLoadOnce) - return 
Problem::BlockFmhaShape::kN0; - else - return Problem::BlockFmhaShape::kK4; - }(); - - return MakeXTLdsBlockDescriptor(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor() - { - using OGradDataType = remove_cvref_t; - constexpr index_t Banks = 32; // TODO: need change based on arch - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(OGradDataType); - constexpr index_t kKPack = GetSmemKPackOGrad(); - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; - constexpr index_t kKPerBlock = [&]() { - if constexpr(OGradTLoadOnce) - return Problem::BlockFmhaShape::kM0; - else - return Problem::BlockFmhaShape::kK1; - }(); - - return MakeXTLdsBlockDescriptor(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor() - { - using BiasDataType = remove_cvref_t; - constexpr index_t Banks = 32; // TODO: need change based on arch - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(BiasDataType); - constexpr index_t kKPack = GetSmemKPackBias(); - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - - static_assert(PixelsPerRow % kKPack == 0); - constexpr index_t NPerRow = PixelsPerRow / kKPack; - static_assert(kNPerBlock % NPerRow == 0); - static_assert(kMPerBlock % kKPack == 0); - - constexpr auto biast_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}, - number{}), - make_tuple(number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); - - constexpr auto biast_lds_block_desc = transform_tensor_descriptor( - biast_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return biast_lds_block_desc; - } - - template - 
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() - { - constexpr index_t smem_size_q = sizeof(typename Problem::QDataType) * - MakeQLdsBlockDescriptor().get_element_space_size(); - return smem_size_q; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQT() - { - constexpr index_t smem_size_qt = [&]() { - if constexpr(QLoadOnce && !QTLoadOnce) - return 0; - else - return sizeof(typename Problem::QDataType) * - MakeQTLdsBlockDescriptor().get_element_space_size(); - }(); - return smem_size_qt; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() - { - constexpr index_t smem_size_k = sizeof(typename Problem::KDataType) * - MakeKLdsBlockDescriptor().get_element_space_size(); - return smem_size_k; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKT() - { - constexpr index_t smem_size_kt = [&]() { - if constexpr(KLoadOnce && !KTLoadOnce) - return 0; - else - return sizeof(typename Problem::KDataType) * - MakeKTLdsBlockDescriptor().get_element_space_size(); - }(); - return smem_size_kt; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() - { - constexpr index_t smem_size_v = [&]() { - if constexpr(VLoadOnce) - return 0; - else - return sizeof(typename Problem::VDataType) * - MakeVLdsBlockDescriptor().get_element_space_size(); - }(); - return smem_size_v; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGrad() - { - constexpr index_t smem_size_do = - sizeof(typename Problem::OGradDataType) * - MakeOGradLdsBlockDescriptor().get_element_space_size(); - return smem_size_do; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGradT() - { - constexpr index_t smem_size_dot = [&]() { - if constexpr(OGradLoadOnce && !OGradTLoadOnce) - return 0; - else - return sizeof(typename Problem::OGradDataType) * - 
MakeOGradTLdsBlockDescriptor().get_element_space_size(); - }(); - return smem_size_dot; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSGrad() - { - constexpr index_t smem_size_ds = - sizeof(typename Problem::GemmDataType) * - MakeSGradLdsBlockDescriptor().get_element_space_size(); - return smem_size_ds; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeBias() - { - constexpr index_t smem_size_bias = [&]() { - if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - return sizeof(typename Problem::BiasDataType) * - MakeBiasTLdsBlockDescriptor().get_element_space_size(); - else - return 0; - }(); - return smem_size_bias; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() - { - constexpr index_t smem_size_q = GetSmemSizeQ(); - constexpr index_t smem_size_qt = GetSmemSizeQT(); - constexpr index_t smem_size_k = GetSmemSizeK(); - constexpr index_t smem_size_kt = GetSmemSizeKT(); - constexpr index_t smem_size_v = GetSmemSizeV(); - constexpr index_t smem_size_do = GetSmemSizeOGrad(); - constexpr index_t smem_size_dot = GetSmemSizeOGradT(); - constexpr index_t smem_size_ds = GetSmemSizeSGrad(); - constexpr index_t smem_size_bias = GetSmemSizeBias(); - constexpr index_t smem_size_transpose = max(smem_size_ds, smem_size_bias); - - index_t smem_size = 0; - - if constexpr(QLoadOnce && OGradLoadOnce) - smem_size += smem_size_q + smem_size_qt + smem_size_do + smem_size_dot + - smem_size_transpose; // 1~4 & 10 - else if(QLoadOnce && !OGradLoadOnce && !OGradTLoadOnce) - smem_size += smem_size_q + smem_size_qt + - max(smem_size_do, - smem_size_dot, - smem_size_transpose); // 5/7/11 TODO: Multiple buffers strategy - else if(!QLoadOnce && !QTLoadOnce && OGradLoadOnce) - smem_size += smem_size_do + smem_size_dot + - max(smem_size_q, - smem_size_qt, - smem_size_transpose); // 6/8/12 TODO: Multiple buffers strategy - else if(!QLoadOnce && !QTLoadOnce && 
!OGradLoadOnce && !OGradTLoadOnce) - smem_size += max(smem_size_q, - smem_size_qt, - smem_size_do, - smem_size_dot, - smem_size_transpose); // 9/13 TODO: Multiple buffers strategy - - // 14/15 needs to be adjusted - if constexpr(KLoadOnce) - smem_size += (smem_size_k + smem_size_kt); // 1~13 - else - smem_size = - max(smem_size_k, smem_size_kt, smem_size); // 14/15 TODO: Multiple buffers strategy - - return max(smem_size, smem_size_v); // 15 - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution() - { - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - - constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane; - constexpr index_t N0 = NWarp; - - constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * 2; - constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane; - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / 2; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM); + constexpr index_t K1 = GetAlignmentK(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N1 = get_warp_size() / K0; + constexpr index_t N0 = kBlockSize / get_warp_size(); + constexpr index_t N2 = kNPerBlock / (N1 * N0); return make_static_tile_distribution( - tile_distribution_encoding, - tuple>, - tuple, sequence<1, 0>>, - tuple, sequence<3, 1>>, - sequence<1, 1, 1>, - sequence<0, 2, 4>>{}); + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); } template CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution() { - using VDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = 
Problem::BlockFmhaShape::kN0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; - constexpr index_t K1 = 16 / sizeof(VDataType); + constexpr index_t K1 = GetAlignmentV(); constexpr index_t K0 = kKPerBlock / K1; constexpr index_t N2 = get_warp_size() / K0; - // coalesce reading for each blocks constexpr index_t N1 = kBlockSize / get_warp_size(); constexpr index_t N0 = kNPerBlock / (N2 * N1); @@ -759,56 +412,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = [&]() { - if constexpr(QLoadOnce) - return Problem::BlockFmhaShape::kQKHeaddim; - else - return Problem::BlockFmhaShape::kK0; - }(); + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t K1 = GetAlignmentQ(); constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - // coalesce reading for each blocks - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); + constexpr index_t M1 = get_warp_size() / K0; + constexpr index_t M0 = kBlockSize / get_warp_size(); + constexpr index_t M2 = kMPerBlock / (M1 * M0); return make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, + tuple, sequence<1, 0>>, sequence<1, 2>, - sequence<0, 1>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = [&]() { - if constexpr(KLoadOnce) - return Problem::BlockFmhaShape::kQKHeaddim; - else - return Problem::BlockFmhaShape::kK0; - }(); - - constexpr index_t K1 = GetAlignmentK(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - // coalesce reading for each blocks - constexpr index_t N1 
= kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + sequence<2, 1>>{}); } template @@ -817,27 +435,72 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = [&]() { - if constexpr(OGradLoadOnce) - return Problem::BlockFmhaShape::kVHeaddim; - else - return Problem::BlockFmhaShape::kK2; - }(); + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t K1 = GetAlignmentOGrad(); constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - // coalesce reading for each blocks - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); + constexpr index_t M1 = get_warp_size() / K0; + constexpr index_t M0 = kBlockSize / get_warp_size(); + constexpr index_t M2 = kMPerBlock / (M1 * M0); return make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, + tuple, sequence<1, 0>>, sequence<1, 2>, - sequence<0, 1>>{}); + sequence<2, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution() + { + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + + // Duplicate dimension + constexpr index_t N0 = NWarp; + constexpr index_t N1 = + (get_warp_size() / kMPerBlock) > 1 ? (get_warp_size() / kMPerBlock) : 1; + + constexpr index_t M0 = MWarp; + constexpr index_t M1 = (get_warp_size() / kMPerBlock) > 1 ? 
kMPerBlock : get_warp_size(); + constexpr index_t M2 = + (get_warp_size() / kMPerBlock) > 1 ? 1 : (kMPerBlock / get_warp_size()); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 1>>, + sequence<1>, + sequence<2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t N1 = GetAlignmentBias(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t M1 = get_warp_size() / N0; + constexpr index_t M0 = kBlockSize / get_warp_size(); + constexpr index_t M2 = kMPerBlock / (M1 * M0); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); } template @@ -881,463 +544,1377 @@ struct BlockFmhaBwdPipelineDefaultPolicy } template - CK_TILE_DEVICE static constexpr auto MakeQTDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution() { + using AccDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPerBlock = [&]() { - if constexpr(QTLoadOnce) - return Problem::BlockFmhaShape::kM0; - else - return Problem::BlockFmhaShape::kK3; - }(); + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kKPerBlock = Problem::kQKHeaddim; - constexpr index_t N1 = GetTransposedAlignmentQ(); - constexpr index_t N0 = kNPerBlock / N1; // P + constexpr index_t K1 = 16 / sizeof(AccDataType); + constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
- constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackQ(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - static_assert(kKPerBlock == K0 * K1 * K2 * K3); + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M1 * M2); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); + tuple, sequence, sequence>, + tuple, sequence<2, 3>>, + tuple, sequence<2, 0>>, + sequence<1, 2, 3>, + sequence<0, 0, 1>>{}); } template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQTRegBlockDescriptor() + CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDramTileDistribution() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPerBlock = [&]() { - if constexpr(QTLoadOnce) - return Problem::BlockFmhaShape::kM0; - else - return Problem::BlockFmhaShape::kK3; - }(); + using AccDataType = remove_cvref_t; - constexpr index_t N1 = GetTransposedAlignmentQ(); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
- constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackQ(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kKPerBlock = Problem::kQKHeaddim; + + constexpr index_t K1 = 16 / sizeof(AccDataType); + constexpr index_t K0 = kKPerBlock / K1; + + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M1 * M2); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, sequence<1, 2>, - sequence<1, 3>>{}); + sequence<0, 1>>{}); + } + + // these are for lds + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() + { + return GetAlignmentQ(); } template - CK_TILE_DEVICE static constexpr auto MakeKTDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQT() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPerBlock = [&]() { - if constexpr(KTLoadOnce) - return Problem::BlockFmhaShape::kN0; - else - return Problem::BlockFmhaShape::kK4; - }(); - - constexpr index_t N1 = GetTransposedAlignmentK(); - constexpr index_t N0 = kNPerBlock / N1; // P - - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
- constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackK(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - static_assert(kKPerBlock == K0 * K1 * K2 * K3); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); + return GetTransposedAlignmentQ(); } template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKTRegBlockDescriptor() + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPerBlock = [&]() { - if constexpr(KTLoadOnce) - return Problem::BlockFmhaShape::kN0; - else - return Problem::BlockFmhaShape::kK4; - }(); - - constexpr index_t N1 = GetTransposedAlignmentK(); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
- constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackK(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<1, 2>, - sequence<1, 3>>{}); + return GetAlignmentK(); } template - CK_TILE_DEVICE static constexpr auto MakeOGradTDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackKT() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; - constexpr index_t kKPerBlock = [&]() { - if constexpr(OGradTLoadOnce) - return Problem::BlockFmhaShape::kM0; - else - return Problem::BlockFmhaShape::kK1; - }(); - - constexpr index_t N1 = GetTransposedAlignmentOGrad(); - constexpr index_t N0 = kNPerBlock / N1; // P - - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
- constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackOGrad(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - static_assert(kKPerBlock == K0 * K1 * K2 * K3); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); + return GetTransposedAlignmentK(); } template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradTRegBlockDescriptor() + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; - constexpr index_t kKPerBlock = [&]() { - if constexpr(OGradTLoadOnce) - return Problem::BlockFmhaShape::kM0; - else - return Problem::BlockFmhaShape::kK1; - }(); - - constexpr index_t N1 = GetTransposedAlignmentOGrad(); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
- constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackOGrad(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<1, 2>, - sequence<1, 3>>{}); + return GetAlignmentV(); } template - CK_TILE_DEVICE static constexpr auto MakeBiasTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias() + { + return GetAlignmentBias(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBiasT() + { + return GetTransposedAlignmentBias(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad() + { + return GetAlignmentOGrad(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGradT() + { + return GetTransposedAlignmentOGrad(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad() + { + // TODO: this is for 3d layout + using GemmDataType = remove_cvref_t; + return 16 / sizeof(GemmDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() + { + constexpr auto DataTypeSize = 2; // sizeof(F16/BF16) + constexpr auto MNLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 
1 : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor( + x_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + x_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto x_lds_block_desc = transform_tensor_descriptor( + x_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return x_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor() + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. 
+ constexpr auto MNPerXDL = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + constexpr auto kBlockSize = Problem::kBlockSize; + + constexpr auto MN0 = MNPerBlock / KPack; + constexpr auto MN1 = KPack; + + constexpr auto KThreadWrite = kBlockSize / MN0; + constexpr auto K0Number = KPerBlock / KPackT; + constexpr auto K0PerThreadWrite = K0Number / KThreadWrite; + constexpr auto KThreadRead = get_warp_size() / MNPerXDL; // assume 32x32x8 mfma + constexpr auto K0PerThreadRead = K0Number / KThreadRead; + + constexpr auto kfold = (KPackT * MN0 * 2 > 128) ? 1 : 128 / (KPackT * MN0 * 2); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mnpair<=n0 + constexpr auto mnpair = + (KPackT * MNPerXDL * 2 > 128) + ? 1 + : ((128 / (KPackT * MNPerXDL * 2)) > MN0 ? MN0 : 128 / (KPackT * MNPerXDL * 2)); + + constexpr auto xt_lds_block_desc_raw = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + KPackT), + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto xt_lds_block_desc_permuted = transform_tensor_descriptor( + xt_lds_block_desc_raw, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(KPackT)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto xt_lds_block_desc_unmerged = transform_tensor_descriptor( + xt_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + 
make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(KPackT)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + constexpr auto xt_lds_block_desc = transform_tensor_descriptor( + xt_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return xt_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPack = GetSmemKPackK(); - constexpr index_t N1 = GetTransposedAlignmentBias(); - constexpr index_t N0 = kNPerBlock / N1; // P + return MakeXLdsBlockDescriptor(); + } - constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
- constexpr index_t M3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackBias(); - static_assert(kKPack % M3 == 0); - constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave - constexpr index_t M1 = get_warp_size() / (M2 * N0); - constexpr index_t M0 = kBlockSize / get_warp_size(); - static_assert(kMPerBlock == M0 * M1 * M2 * M3); + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto k_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); + + return k_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr 
index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto k_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); + + return k_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + + constexpr index_t kVPack = GetSmemKPackV(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto v_block_dstr = 
make_static_tile_distribution(v_block_dstr_encode); + + return v_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); + + return v_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t K1 = GetAlignmentK(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = GetTransposedAlignmentK(); + constexpr index_t N1 = get_warp_size() / K0; + constexpr index_t N0 = kBlockSize / get_warp_size(); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2, 1>>, - tuple, sequence<1, 0, 2>>, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<2, 1>, + sequence<1, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto 
MakeShuffledKLdsWriteBlockDescriptor() + { + // Hold all data + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kKPackT = GetSmemKPackKT(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsReadBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + auto shuffled_k_lds_block_desc = MakeShuffledKLdsWriteBlockDescriptor(); + + return transform_tensor_descriptor( + shuffled_k_lds_block_desc, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto kt_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, sequence<1, 2>, - sequence<3, 1>>{}); + sequence<0, 0>>{}; + + constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto kt_block_dstr = 
make_static_tile_distribution(kt_block_dstr_encode); + + return kt_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t kKPack = GetSmemKPackQ(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto q_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); + + return q_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQRegWriteBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t K1 = GetAlignmentQ(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = GetTransposedAlignmentQ(); + constexpr index_t N1 = get_warp_size() / K0; + constexpr index_t N0 = kBlockSize / get_warp_size(); + + return 
make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<2, 1>, + sequence<1, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQLdsWriteBlockDescriptor() + { + // Hold full block data + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; + + constexpr index_t kKPack = GetSmemKPackQ(); + constexpr index_t kKPackT = GetSmemKPackQT(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsReadBlockDescriptor() + { + // Hold full block data + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; + + auto shuffled_q_lds_block_desc = MakeShuffledQLdsWriteBlockDescriptor(); + + return transform_tensor_descriptor( + shuffled_q_lds_block_desc, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto qt_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + 
sequence<0, 0>>{}; + + constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto qt_block_dstr = make_static_tile_distribution(qt_block_dstr_encode); + + return qt_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSGradTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto dst_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode); + + return dst_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsWriteBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + using LSEDType = remove_cvref_t; + constexpr index_t kMPack = 16 / sizeof(LSEDType); + + constexpr auto lsed_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(number{}), + make_tuple(number<1>{}), + number{}, + number<1>{}); + + return lsed_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor() + { + constexpr auto config = 
BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + + constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane; + constexpr index_t N0 = NWarp; + + // M4 *2 and M2 /2 when swizzle mode enabled + constexpr index_t SwizzleConfig = WG::kM == 16 ? 1 : 2; + // constexpr index_t SwizzleConfig = 1; + constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * SwizzleConfig; + constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane; + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / SwizzleConfig; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple>, + tuple, sequence<1, 0>>, + tuple, sequence<3, 1>>, + sequence<1, 1, 1>, + sequence<0, 2, 4>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor() + { + // Hold full block data + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + + constexpr index_t kKPack = GetSmemKPackOGrad(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + 
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto do_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode); + + return do_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradRegWriteBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + + constexpr index_t K1 = GetAlignmentOGrad(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = GetTransposedAlignmentOGrad(); + constexpr index_t N1 = get_warp_size() / K0; + constexpr index_t N0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<2, 1>, + sequence<1, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradLdsWriteBlockDescriptor() + { + // Hold all data + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; + + constexpr index_t kKPack = GetSmemKPackOGrad(); + constexpr index_t kKPackT = GetSmemKPackOGradT(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsReadBlockDescriptor() + { + // Hold all data + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; + auto shuffled_do_lds_block_desc = MakeShuffledOGradLdsWriteBlockDescriptor(); + + return transform_tensor_descriptor( + shuffled_do_lds_block_desc, + 
make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + // constexpr index_t kNPerBlock = 32; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto dot_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto dot_block_dstr = make_static_tile_distribution(dot_block_dstr_encode); + + return dot_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); 
+ constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto pt_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode); + + return pt_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPack = GetSmemKPackSGrad(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSGradRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK4; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto ds_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto ds_block_dstr = make_static_tile_distribution(ds_block_dstr_encode); + + return ds_block_dstr; + } + + template + CK_TILE_DEVICE static constexpr void 
PTFromGemm0CToGemm1A(PTOutTensor& pt_out, + const PInTensor& p_in) + { + if constexpr(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 16) + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + auto pt_warp_tensor = + make_static_distributed_tensor(CWarpDstr{}); + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + pt_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + pt_out.set_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), + pt_warp_tensor.get_thread_buffer()); + }); + }); + } + else + { + pt_out.get_thread_buffer() = p_in.get_thread_buffer(); + } + } + + template + CK_TILE_DEVICE static constexpr void SGradTFromGemm2CToGemm3A(SGradTOutTensor& dst_out, + const SGradInTensor& ds_in) + { + if constexpr(Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}) == 16) + { + using BlockGemm = 
remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + auto dst_warp_tensor = + make_static_distributed_tensor(CWarpDstr{}); + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + dst_out.set_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), + dst_warp_tensor.get_thread_buffer()); + }); + }); + } + else + { + dst_out.get_thread_buffer() = ds_in.get_thread_buffer(); + } } template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t N1 = GetTransposedAlignmentBias(); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t total_pixels = kMPerBlock * kNPerBlock / 
kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? - constexpr index_t M3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackBias(); - static_assert(kKPack % M3 == 0); - constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave - constexpr index_t M1 = get_warp_size() / (M2 * N0); + constexpr index_t N1 = GetAlignmentBias(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t M2 = GetTransposedAlignmentBias(); + constexpr index_t M1 = get_warp_size() / N0; constexpr index_t M0 = kBlockSize / get_warp_size(); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2, 1>>, - tuple, sequence<1, 0, 2>>, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, sequence<2, 1>, - sequence<1, 3>>{}); + sequence<1, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsBlockDescriptor() + { + // Hold full block data + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + + constexpr index_t kKPack = GetSmemKPackBias(); + constexpr index_t kKPackT = GetSmemKPackBiasT(); + + return MakeXTLdsBlockDescriptor(); } template - CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasSTileDistribution() { using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile()); return c_block_tensor_type::get_tile_distribution(); } template - CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeQ() { - using BlockGemmProblem = - BlockGemmPipelineProblem>; + constexpr index_t smem_size_q = sizeof(typename Problem::QDataType) * + MakeQLdsBlockDescriptor().get_element_space_size(); + return smem_size_q; + } - constexpr auto warp_gemm = []() { - if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - 
return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; - } + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeQT() + { + constexpr index_t smem_size_qt = + sizeof(typename Problem::QDataType) * + MakeShuffledQLdsWriteBlockDescriptor().get_element_space_size(); + + return smem_size_qt; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeK() + { + constexpr index_t smem_size_k = + sizeof(typename Problem::KDataType) * + MakeKLdsWriteBlockDescriptor().get_element_space_size(); + return smem_size_k; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeKT() + { + constexpr index_t smem_size_kt = + sizeof(typename Problem::KDataType) * + MakeKTLdsReadBlockDescriptor().get_element_space_size(); + return smem_size_kt; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE() + { + constexpr index_t smem_size_lse = + sizeof(typename Problem::LSEDataType) * + MakeLSEDLdsWriteBlockDescriptor().get_element_space_size(); + return smem_size_lse; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD() + { + constexpr index_t smem_size_d = + sizeof(typename Problem::DDataType) * + MakeLSEDLdsWriteBlockDescriptor().get_element_space_size(); + return smem_size_d; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeV() + { + constexpr index_t smem_size_v = + sizeof(typename Problem::VDataType) * + MakeVLdsWriteBlockDescriptor().get_element_space_size(); + return smem_size_v; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGrad() + { + constexpr index_t smem_size_do = + sizeof(typename Problem::OGradDataType) * + MakeOGradLdsBlockDescriptor().get_element_space_size(); + return smem_size_do; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGradT() + { + constexpr index_t 
smem_size_dot = + sizeof(typename Problem::OGradDataType) * + MakeShuffledOGradLdsWriteBlockDescriptor().get_element_space_size(); + return smem_size_dot; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeSGrad() + { + constexpr index_t smem_size_ds = + sizeof(typename Problem::GemmDataType) * + MakeSGradLdsBlockDescriptor().get_element_space_size(); + return smem_size_ds; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeBias() + { + constexpr index_t smem_size_bias = [&]() { + if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return sizeof(typename Problem::BiasDataType) * + MakeBiasLdsBlockDescriptor().get_element_space_size(); + else + return 0; }(); - - using BlockGemmPolicy = - BlockGemmASmemBSmemCRegV1CustomPolicy; - - return BlockGemmASmemBSmemCRegV1{}; + return smem_size_bias; } template - CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - using BlockGemmProblem = - BlockGemmPipelineProblem>; + constexpr index_t smem_size_q = GetSmemSizeQ(); + constexpr index_t smem_size_qt = GetSmemSizeQT(); + constexpr index_t smem_size_lse = GetSmemSizeLSE(); + constexpr index_t smem_size_k = GetSmemSizeK(); + constexpr index_t smem_size_kt = GetSmemSizeKT(); + constexpr index_t smem_size_v = GetSmemSizeV(); + constexpr index_t smem_size_do = GetSmemSizeOGrad(); + constexpr index_t smem_size_dot = GetSmemSizeOGradT(); + constexpr index_t smem_size_d = GetSmemSizeD(); + constexpr index_t smem_size_ds = GetSmemSizeSGrad(); + constexpr index_t smem_size_bias = GetSmemSizeBias(); - using WarpGemm = - WarpGemmMfmaDispatcher{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), - true>; - using BlockGemmPolicy = - BlockGemmARegBSmemCRegV1CustomPolicy; - return BlockGemmARegBSmemCRegV1{}; + constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt; 
+ constexpr index_t smem_size_stage0_1 = smem_size_v; + constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot + + smem_size_do + smem_size_lse + smem_size_d + + max(smem_size_bias, smem_size_ds); + + return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1); } - template - CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() + template + struct HotLoopScheduler { - using BlockGemmProblem = - BlockGemmPipelineProblem>; + using Problem = Problem_; - constexpr auto warp_gemm = []() { - if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; - } - }(); + template + CK_TILE_DEVICE static constexpr void GemmStagedScheduler() + { + } - using BlockGemmPolicy = - BlockGemmASmemBRegCRegV1CustomPolicy; + template <> + CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>() + { + // Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load + // Comp: Q x K + constexpr index_t VMEM_READ_INST = + Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ; + constexpr index_t LDS_READ_INST = OGradT_LDS_READ; + constexpr index_t MFMA_INST = Gemm0MFMA; - return BlockGemmASmemBRegCRegV1{}; - } + // Evenly distributed to relieve SQ->TA FIFO pressure + constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST; + constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST; + // To hide instruction issue latency + constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST; - // template - // CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() - // { - // using BlockGemmProblem = - // BlockGemmPipelineProblem>; - // constexpr auto warp_gemm = []() { - // if constexpr(std::is_same_v && - // std::is_same_v && - // std::is_same_v) - // { - // return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; - // 
} - // else if constexpr(std::is_same_v && - // std::is_same_v && - // std::is_same_v) - // { - // return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; - // } - // }(); + static_for<0, VMEM_READ_INST, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) { + ignore = j; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read + }); + }); + static_for<0, MFMA_Remainder, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read + }); + } - // using BlockGemmPolicy = - // BlockGemmASmemBSmemCRegV1CustomPolicy; + template <> + CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>() + { + // Mem: Q^T LDS load + // Comp: OGrad x V + constexpr index_t LDS_READ_INST = QT_LDS_READ; + constexpr index_t MFMA_INST = Gemm1MFMA; - // return BlockGemmASmemBSmemCRegV1{}; - // } + // To hide instruction issue latency + constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST; - template - CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() - { - using BlockGemmProblem = - BlockGemmPipelineProblem>; + static_for<0, MFMA_INST, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read + }); + } - using WarpGemm = - WarpGemmMfmaDispatcher{}), - Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}), - true>; - using BlockGemmPolicy = - BlockGemmARegBSmemCRegV1CustomPolicy; - return BlockGemmARegBSmemCRegV1{}; - } + template <> + CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>() + { + // Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store + // Comp: PT x OGrad + constexpr index_t 
LDS_WRITE_INST = Q_LDS_WRITE + QT_LDS_WRITE + OGrad_LDS_WRITE + + OGradT_LDS_WRITE + LSE_LDS_WRITE + D_LDS_WRITE; + constexpr index_t MFMA_INST = Gemm2MFMA; - template - CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() - { - using BlockGemmProblem = - BlockGemmPipelineProblem>; + // To hide instruction issue latency + constexpr index_t LDS_WRITE_PER_MFMA = LDS_WRITE_INST / MFMA_INST; - using WarpGemm = - WarpGemmMfmaDispatcher{}), - Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), - true>; - using BlockGemmPolicy = - BlockGemmASmemBSmemCRegV1CustomPolicy; - return BlockGemmASmemBSmemCRegV1{}; - } + static_for<0, MFMA_INST, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write + }); + } + + template <> + CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>() + { + // Mem: SGradT LDS store, SGrad, Q, LSE LDS load. + // Comp: SGradT x QT + constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE; + constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ; + constexpr index_t MFMA_INST = Gemm3MFMA; + + // To hide instruction issue latency + constexpr index_t LDS_WRITE_PER_MFMA = + LDS_WRITE_INST / MFMA_INST >= 1 ? LDS_WRITE_INST / MFMA_INST : 1; + constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA; + + constexpr index_t LDS_READ_PER_MFMA = + (MFMA_INST - MFMA_INST_LDS_WRITE) > 0 + ? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE) > 0 + ? 
LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE) + : 1 + : 0; + + static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write + }); + + static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read + }); + } + + template <> + CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>() + { + // Mem: SGrad, OGrad, D LDS load. + // Comp: SGrad x KT + constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ; + constexpr index_t MFMA_INST = Gemm4MFMA; + + // To hide instruction issue latency + constexpr index_t LDS_READ_PER_MFMA = + LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1; + + static_for<0, MFMA_INST, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read + }); + } + + private: + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0; + static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0; + static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim; + static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4; + + static constexpr index_t WarpGemmM = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + static constexpr index_t WarpGemmN = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}); + static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 
16 : 8; + static constexpr index_t Gemm4MWarp = + Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); + static constexpr index_t Gemm4NWarp = + Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); + + // Compute + static constexpr index_t Gemm0MFMA = + kM0 * kN0 * kQKHeaddim / + (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm1MFMA = + kM0 * kN0 * kVHeaddim / + (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm2MFMA = + kN0 * kVHeaddim * kM0 / + (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm3MFMA = + kN0 * kQKHeaddim * kM0 / + (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm4MFMA = + kM0 * kQKHeaddim * kN0 / + (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); + + // VMEM + static constexpr index_t Q_VMEM_READ = + kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ(); + static constexpr index_t OGrad_VMEM_READ = + kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + static constexpr index_t LSE_VMEM_READ = 1; + static constexpr index_t D_VMEM_READ = 1; + + // LDS Read + static constexpr index_t OGradT_LDS_READ = + kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad(); + static constexpr index_t QT_LDS_READ = + kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ(); + static constexpr index_t SGradT_LDS_READ_P1 = + kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); + static constexpr index_t Q_LDS_READ = + kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ(); + static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? 
kM0 / (4 * 4) : kM0 / (2 * 4); + static constexpr index_t SGradT_LDS_READ_P2 = + kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); + static constexpr index_t OGrad_LDS_READ = + kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); + + // LDS Write + static constexpr index_t Q_LDS_WRITE = + kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ(); + static constexpr index_t QT_LDS_WRITE = + kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ(); + static constexpr index_t OGrad_LDS_WRITE = + kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + static constexpr index_t OGradT_LDS_WRITE = + kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad(); + static constexpr index_t LSE_LDS_WRITE = 1; + static constexpr index_t D_LDS_WRITE = 1; + static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize; + }; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp index a54a9fcb32..27f58ef2f8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp @@ -8,9 +8,8 @@ namespace ck_tile { // This class is used for codegen pattern matching enum class BlockFmhaBwdPipelineEnum { - KSKTSVR = 0, - QSKSVROGradS, - KSVR, + KRKTRVR_IGLP = 0, + KRKTRVR, }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp index 7b787e9f36..c4c4a745a7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -24,7 +24,9 @@ template struct BlockFmhaBwdPipelineProblem { @@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem using BiasGradDataType 
= remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using FmhaMask = remove_cvref_t; + using FmhaDropout = remove_cvref_t; using Traits = remove_cvref_t; - static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); - static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kIsDeterministic = kIsDeterministic_; // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; @@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; - static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; @@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; +template +struct BlockFmhaBwdConvertQGradPipelineProblem +{ + using AccDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using Traits = remove_cvref_t; + + static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0, + "kBlockSize should be divisible by get_warp_size()"); + + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kM0 = kM0_; + static constexpr index_t kN0 = kN0_; + static constexpr index_t kQKHeaddim = kQKHeaddim_; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kIsDeterministic = kIsDeterministic_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp 
b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 9a9196f273..e6837e2ebe 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -106,4 +106,14 @@ struct TileFmhaBwdOGradDotOTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; +template +struct TileFmhaBwdConvertQGradTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index a89536e6eb..dd313c5480 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -5,6 +5,9 @@ #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp new file mode 100644 index 0000000000..9a5c2aae5c --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block distributed tensor +// C is block distributed tensor +template +struct BlockGemmARegBRegCRegV1 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensor& a_block_tensor, + const BBlockTensor& b_block_tensor) const + { + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + // M->N Warp + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + 
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // check ABC-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "A distribution is wrong!"); + static_assert( + std::is_same_v, + remove_cvref_t>, + "B distribution is wrong!"); + static_assert( + std::is_same_v, + remove_cvref_t>, + "C distribution is wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + 
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + CK_TILE_DEVICE constexpr auto MakeCBlockTile() const + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor, + const BBlockTensor& b_block_tensor) const + { 
+ auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor, b_block_tensor); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp new file mode 100644 index 0000000000..9d494c2831 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockGemmARegBRegCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp new file mode 100644 index 0000000000..b849c48daf --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmARegBRegCRegV1 +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmARegBRegCRegV1DefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp index 84883d6ed8..beab457b90 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp @@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); - // constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; - // constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; - // constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; - constexpr index_t MPerBlock = BlockGemmShape::kM; - constexpr index_t NPerBlock = BlockGemmShape::kN; - constexpr index_t KPerBlock = BlockGemmShape::kK; + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; - // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - // KPerBlock == BlockGemmShape::kK, - // "wrong!"); + static_assert(MPerBlock == BlockGemmShape::kM && 
NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp index 65ce1a9b8f..3d142df4d4 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp @@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1 std::is_same_v>, "wrong!"); - // constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; - // constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}]; - // constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; - constexpr index_t MPerBlock = BlockGemmShape::kM; - constexpr index_t NPerBlock = BlockGemmShape::kN; - constexpr index_t KPerBlock = BlockGemmShape::kK; + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; - // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - // KPerBlock == BlockGemmShape::kK, - // "wrong!"); + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 5b4419b79f..7ca4a697a7 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 = using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl>; +using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl< + 
WarpGemmAtrributeMfmaIterateK_SwizzleA>; + using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl< WarpGemmAtrributeMfmaIterateK_SwizzleA>; @@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl>; +using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl< + WarpGemmAtrributeMfmaIterateK_SwizzleA>; + using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl< WarpGemmAtrributeMfmaIterateK_SwizzleA>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index fd5b004d36..d80e5198e6 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK static_for<0, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, - reinterpret_cast(a_vec) + reinterpret_cast(a_vec) .template get_as()[iKIter], - reinterpret_cast(b_vec) + reinterpret_cast(b_vec) .template get_as()[iKIter]); }); } @@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK // c = a * b auto c_vec = Impl{}( - reinterpret_cast(a_vec).template get_as()[I0], - reinterpret_cast(b_vec).template get_as()[I0]); + reinterpret_cast(a_vec).template get_as()[I0], + reinterpret_cast(b_vec).template get_as()[I0]); // c += a * b static_for<1, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, - reinterpret_cast(a_vec) + reinterpret_cast(a_vec) .template get_as()[iKIter], - reinterpret_cast(b_vec) + reinterpret_cast(b_vec) .template get_as()[iKIter]); }); diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 8d12130308..99cd5d787e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -15,7 +15,8 @@ template + bool TransposeC, + bool SwizzleA = false> struct WarpGemmMfmaDispatcher; // clang-format off @@ -29,6 
+30,9 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; + // bf16 template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; @@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; + // fp8 template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; @@ -58,8 +65,15 @@ template -using WarpGemmMfmaDispatcher = typename impl:: - WarpGemmMfmaDispatcher::Type; + bool TransposeC, + bool SwizzleA = false> +using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher::Type; } // namespace ck_tile diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 0cd54c7788..e4efae6173 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -100,18 +100,17 @@ list(APPEND GEMM_INSTANCES 
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp) -if(GPU_TARGETS MATCHES "gfx94") - list(APPEND GEMM_INSTANCES - device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_default_instance.cpp - device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_default_instance.cpp - device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_default_instance.cpp - device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_padded_instance.cpp - device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_padded_instance.cpp - device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_padded_instance.cpp - device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp - device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp - device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp) -endif() +list(APPEND GEMM_INSTANCES + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_default_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_default_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_default_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_padded_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_padded_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_padded_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt index cc4ce76606..5b6e985e59 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt @@ -51,8 +51,7 @@ 
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -if(GPU_TARGETS MATCHES "gfx94") - list(APPEND GEMM_UNIVERSAL_INSTANCES +list(APPEND GEMM_UNIVERSAL_INSTANCES device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -123,6 +122,6 @@ if(GPU_TARGETS MATCHES "gfx94") set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -endif() + add_instance_library(device_gemm_universal_instance ${GEMM_UNIVERSAL_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp index 12994aeecd..3b930e9894 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp @@ -35,12 +35,13 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +// clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | 
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - +#ifdef __gfx94__ + //Only enable these instances on gfx94x // Compute friendly DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, @@ -55,6 +56,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 128, 16, 4, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 128, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index 7bb7e71c54..29fa8fa3c5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -15,7 +15,7 @@ set(GROUPED_CONV3D_BWD_DATA wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp) -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_BWD_DATA xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp) endif() diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 1a9c455220..8e939c15a9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -30,7 +30,7 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp) -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_BWD_WEIGHT 
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) endif() diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt index 5781f07080..329e8e4c7f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt @@ -4,7 +4,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT_BILINEAR xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_BWD_WEIGHT_BILINEAR xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) endif() diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt index be54eb4adf..9a42d1ec3a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt @@ -4,7 +4,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT_SCALE xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES 
MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_BWD_WEIGHT_SCALE xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) endif() diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 9bb6d807e6..5a346f1e94 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -43,22 +43,22 @@ set(GROUPED_CONV3D_FWD wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp ) -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp) endif() -if((DTYPES MATCHES "fp8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) +if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp) endif() -if((DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) +if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp) endif() -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp) list(APPEND GROUPED_CONV3D_FWD diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 356aec7a08..5318de5e8b 100644 --- 
a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -136,9 +136,10 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::string best_op_name; - float best_avg_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + ck::index_t best_split_k = 1; // profile device Conv instances bool all_pass = true; @@ -167,99 +168,115 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, range_copy(conv_param.input_left_pads_, begin(input_left_pads)); range_copy(conv_param.input_right_pads_, begin(input_right_pads)); + std::vector split_k_list = {1, 2, 4, 8, 16, 32, 64, 128}; + + if(split_k > 0) + { + split_k_list = {split_k}; + } + for(auto& op_ptr : op_ptrs) { - auto argument_ptr = - op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - input_lengths, - input_strides, - filter_lengths, - weights_strides, - output_lengths, - output_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - in_element_op, - wei_element_op, - out_element_op, - split_k); - - const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); - DeviceMem workspace_dev(workspace_sz); - op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) + for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++) { - // using atomic add, so need to reset input - wei_device_buf.SetZero(); + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + 
static_cast(out_device_buf.GetDeviceBuffer()), + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k_list[split_k_id]); - std::string op_name = op_ptr->GetTypeString(); + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - float avg_time = - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - std::size_t flop = conv_param.GetFlops(); - std::size_t num_btype = conv_param.GetByte(); - - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_btype / 1.E6 / avg_time; - - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) + if(op_ptr->IsSupportedArgument(argument_ptr.get())) { - best_op_name = op_name; - best_tflops = tflops; - best_avg_time = avg_time; - best_gb_per_sec = gb_per_sec; - } + // using atomic add, so need to reset input + wei_device_buf.SetZero(); - if(do_verification) - { - wei_device_buf.FromDevice(weight_device_result.mData.data()); + std::string op_name = op_ptr->GetTypeString(); - bool pass = ck::utils::check_err(weight_device_result, weight_host_result); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); - if(!pass) + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops + << 
" TFlops, " << gb_per_sec << " GB/s, " << op_name << ", SplitK " + << split_k_list[split_k_id] << std::endl; + + if(tflops > best_tflops) { - std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl; + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_split_k = split_k_list[split_k_id]; } - all_pass &= pass; - - if(do_log) + if(do_verification) { - LogRangeAsType(std::cout << "output : ", output.mData, ",") << std::endl; - ; - LogRangeAsType( - std::cout << "weight (device): ", weight_device_result.mData, ",") - << std::endl; - ; - LogRangeAsType( - std::cout << "weight (host): ", weight_host_result.mData, ",") - << std::endl; - ; - LogRangeAsType(std::cout << "input: ", input.mData, ",") << std::endl; - ; + wei_device_buf.FromDevice(weight_device_result.mData.data()); + + bool pass = ck::utils::check_err(weight_device_result, weight_host_result); + + if(!pass) + { + std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl; + } + + all_pass &= pass; + + if(do_log) + { + LogRangeAsType(std::cout << "output : ", output.mData, ",") + << std::endl; + ; + LogRangeAsType( + std::cout << "weight (device): ", weight_device_result.mData, ",") + << std::endl; + ; + LogRangeAsType( + std::cout << "weight (host): ", weight_host_result.mData, ",") + << std::endl; + ; + LogRangeAsType(std::cout << "input: ", input.mData, ",") + << std::endl; + ; + } } } - } - else - { - std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } } } std::cout << "Best configuration parameters:" << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK " + << best_split_k << std::endl; return all_pass; } diff --git 
a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index 6ed7cf5e48..7dd75a5e0a 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -81,7 +81,6 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv); ck::index_t split_k = std::stoi(argv[8 + 1 + 4 + 6 * num_dim_spatial]); - split_k = std::max(1, split_k); using F32 = float; using F16 = ck::half_t; diff --git a/script/convert_miopen_driver_to_profiler.py b/script/convert_miopen_driver_to_profiler.py new file mode 100644 index 0000000000..47135f3401 --- /dev/null +++ b/script/convert_miopen_driver_to_profiler.py @@ -0,0 +1,386 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+# Convert miopen driver command to ck Profiler
+# Example: python3 ../script/convert_miopen_driver_to_profiler.py
+# /opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 -k 64 -y 3 -x 3
+# -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -m conv -g 32 -F 1 -t 1
+
+import argparse
+import subprocess
+import sys
+
+# ckProfiler data type index for each MIOpen data type and profiler
+# operation; a missing operation means the combination is not supported.
+DATA_TYPE_IDS = {
+    "fp32": {"grouped_conv_fwd": 0, "grouped_conv_bwd_data": 0,
+             "grouped_conv_bwd_weight": 0},
+    "fp16": {"grouped_conv_fwd": 1, "grouped_conv_bwd_data": 1,
+             "grouped_conv_bwd_weight": 1},
+    "bfp16": {"grouped_conv_fwd": 2, "grouped_conv_bwd_data": 2,
+              "grouped_conv_bwd_weight": 2},
+    "int8": {"grouped_conv_fwd": 3, "grouped_conv_bwd_weight": 4},
+}
+
+
+def init_const_args(args):
+    """Set the ckProfiler options that are not derived from MIOpen flags."""
+    args.ck_profiler_cmd = '../build/bin/ckProfiler'
+    # use decimal values
+    args.init_method = 2
+    # don't print tensor values
+    args.log_value = 0
+
+
+def run_ck_profiler_cmd(cmd):
+    """Print the ckProfiler command (a list of arguments) and run it."""
+    print("ckProfiler command:")
+    print(cmd)
+    subprocess.run(cmd)
+
+
+def parse_data_type(args):
+    """Set args.data_type to the ckProfiler index for args.ck_profiler_op.
+
+    The MIOpen type name is remembered in args.miopen_data_type, so the
+    lookup stays correct when several operations are profiled in one run
+    (args.data_type is overwritten with an integer by the first call).
+    Exits with an error when the type is not supported by the operation.
+    """
+    if not hasattr(args, "miopen_data_type"):
+        args.miopen_data_type = args.data_type
+    type_ids = DATA_TYPE_IDS.get(args.miopen_data_type, {})
+    if args.ck_profiler_op not in type_ids:
+        sys.exit('Not supported data type for ' + args.ck_profiler_op)
+    args.data_type = type_ids[args.ck_profiler_op]
+
+
+def add_conv_params_to_cmd(args, cmd):
+    """Append filter, input, stride, dilation and padding sizes to cmd.
+
+    Each padding value is appended twice (presumably the left and right
+    pads expected by ckProfiler - confirm against the profiler usage).
+    """
+    if args.spatial_dim == 1:
+        cmd += [str(args.fil_w), str(args.in_w)]
+        cmd += [str(args.conv_stride_w), str(args.dilation_w)]
+        cmd += [str(args.pad_w), str(args.pad_w)]
+    elif args.spatial_dim == 2:
+        cmd += [str(args.fil_h), str(args.fil_w)]
+        cmd += [str(args.in_h), str(args.in_w)]
+        cmd += [str(args.conv_stride_h), str(args.conv_stride_w)]
+        cmd += [str(args.dilation_h), str(args.dilation_w)]
+        cmd += [str(args.pad_h), str(args.pad_w)]
+        cmd += [str(args.pad_h), str(args.pad_w)]
+    elif args.spatial_dim == 3:
+        cmd += [str(args.fil_d), str(args.fil_h), str(args.fil_w)]
+        cmd += [str(args.in_d), str(args.in_h), str(args.in_w)]
+        cmd += [str(args.conv_stride_d), str(args.conv_stride_h)]
+        cmd += [str(args.conv_stride_w)]
+        cmd += [str(args.dilation_d),
+                str(args.dilation_h),
+                str(args.dilation_w)]
+        cmd += [str(args.pad_d), str(args.pad_h), str(args.pad_w)]
+        cmd += [str(args.pad_d), str(args.pad_h), str(args.pad_w)]
+    else:
+        sys.exit('Not supported spatial dim (supported: 1, 2, 3)')
+
+
+def run_ck_grouped_conv_fwd(args):
+    """Build and run the grouped_conv_fwd ckProfiler command."""
+    args.ck_profiler_op = "grouped_conv_fwd"
+    parse_data_type(args)
+    # default for MIOpen NHWGC
+    args.layout = 1
+    # use int32 by default
+    args.index_type = 0
+
+    cmd = [str(args.ck_profiler_cmd), str(args.ck_profiler_op)]
+    cmd += [str(args.data_type), str(args.layout), str(args.index_type)]
+    cmd += [str(args.verify), str(args.init_method)]
+    cmd += [str(args.log_value), str(args.time)]
+    cmd += [str(args.spatial_dim), str(args.group_count)]
+    cmd += [str(args.batchsize), str(args.out_channels)]
+    cmd += [str(args.in_channels)]
+    add_conv_params_to_cmd(args, cmd)
+
+    run_ck_profiler_cmd(cmd)
+
+
+def run_ck_grouped_conv_bwd_data(args):
+    """Build and run the grouped_conv_bwd_data ckProfiler command."""
+    args.ck_profiler_op = "grouped_conv_bwd_data"
+    parse_data_type(args)
+    # default for MIOpen NHWGC
+    args.layout = 1
+
+    cmd = [str(args.ck_profiler_cmd), str(args.ck_profiler_op)]
+    cmd += [str(args.data_type), str(args.layout)]
+    cmd += [str(args.verify), str(args.init_method)]
+    cmd += [str(args.log_value), str(args.time)]
+    cmd += [str(args.spatial_dim), str(args.group_count)]
+    cmd += [str(args.batchsize), str(args.out_channels)]
+    cmd += [str(args.in_channels)]
+    add_conv_params_to_cmd(args, cmd)
+
+    run_ck_profiler_cmd(cmd)
+
+
+def run_ck_grouped_conv_bwd_weight(args):
+    """Build and run the grouped_conv_bwd_weight ckProfiler command."""
+    args.ck_profiler_op = "grouped_conv_bwd_weight"
+    parse_data_type(args)
+    # default for MIOpen NHWGC
+    args.layout = 2
+    # Test all split K values from the list {1, 2, 4, 8, 16, 32, 64, 128}
+    args.split_k_value = -1
+
+    cmd = [str(args.ck_profiler_cmd), str(args.ck_profiler_op)]
+    cmd += [str(args.data_type), str(args.layout)]
+    cmd += [str(args.verify), str(args.init_method)]
+    cmd += [str(args.log_value), str(args.time)]
+    cmd += [str(args.spatial_dim), str(args.group_count)]
+    cmd += [str(args.batchsize), str(args.out_channels)]
+    cmd += [str(args.in_channels)]
+    add_conv_params_to_cmd(args, cmd)
+
+    cmd += [str(args.split_k_value)]
+    run_ck_profiler_cmd(cmd)
+
+
+def process_miopen_driver_name(args, unknown):
+    """Get name of miopen driver, remove it from unknown.
+
+    The driver name selects args.data_type; exits with an error when no
+    supported driver name is present.
+    """
+    if "convint8" in unknown:
+        args.data_type = 'int8'
+        unknown.remove("convint8")
+    elif "convbfp16" in unknown:
+        args.data_type = 'bfp16'
+        unknown.remove("convbfp16")
+    elif "convfp16" in unknown:
+        args.data_type = 'fp16'
+        unknown.remove("convfp16")
+    elif "conv" in unknown:
+        args.data_type = 'fp32'
+        unknown.remove("conv")
+    else:
+        sys.exit('Not supported driver (supported: conv, convfp16, convint8,'
+                 ' convbfp16).')
+
+
+def run_ck_profiler(args):
+    """Run the ckProfiler operations selected by the -forw flag."""
+    if args.group_count <= 0 or args.in_channels % args.group_count or \
+            args.out_channels % args.group_count:
+        sys.exit('Number of channels must be divisible by group count')
+    # MIOpen get number of channel per all groups, CK profiler get number of
+    # channel per group
+    args.in_channels //= args.group_count
+    args.out_channels //= args.group_count
+
+    if args.forw in (0, 1, 3, 5):
+        run_ck_grouped_conv_fwd(args)
+    if args.forw in (0, 2, 3, 6):
+        run_ck_grouped_conv_bwd_data(args)
+    if args.forw in (0, 4, 5, 6):
+        run_ck_grouped_conv_bwd_weight(args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog="converter",
+        description="Convert miopen driver command to ck Profiler"
+                    "\nExample: python3 "
+                    "../script/convert_miopen_driver_to_profiler.py "
+                    "/opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 "
+                    "-k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g "
+                    "32 -F 1 -t 1",
+    )
+    parser.add_argument(
+        "-in_layout",
+        "-I",
+        default=-1,
+        type=int,
+        help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)"
+    )
+    parser.add_argument(
+        "-forw",
+        "-F",
+        default=0,
+        type=int,
+        help="Flag enables fwd, bwd, wrw convolutions"
+             "\n0 fwd+bwd+wrw (default)"
+             "\n1 fwd only"
+             "\n2 bwd only"
+             "\n4 wrw only"
+             "\n3 fwd+bwd"
+             "\n5 fwd+wrw"
+             "\n6 bwd+wrw"
+    )
+    parser.add_argument(
+        "-spatial_dim",
+        "-_",
+        default=2,
+        type=int,
+        help="convolution spatial dimension (Default-2)"
+    )
+    parser.add_argument(
+        "-batchsize",
+        "-n",
+        default=100,
+        type=int,
+        help="Mini-batch size (Default=100)"
+    )
+    parser.add_argument(
+        "-in_channels",
+        "-c",
+        default=3,
+        type=int,
+        help="Number of Input Channels (Default=3)"
+    )
+    parser.add_argument(
+        "-in_d",
+        "-!",
+        default=32,
+        type=int,
+        help="Input Depth (Default=32)"
+    )
+    parser.add_argument(
+        "-in_h",
+        "-H",
+        default=32,
+        type=int,
+        help="Input Height (Default=32)"
+    )
+    parser.add_argument(
+        "-in_w",
+        "-W",
+        default=32,
+        type=int,
+        help="Input Width (Default=32)"
+    )
+    parser.add_argument(
+        "-out_channels",
+        "-k",
+        default=32,
+        type=int,
+        help="Number of Output Channels (Default=32)"
+    )
+    parser.add_argument(
+        "-fil_d",
+        "-@",
+        default=3,
+        type=int,
+        help="Filter Depth (Default=3)"
+    )
+    parser.add_argument(
+        "-fil_h",
+        "-y",
+        default=3,
+        type=int,
+        help="Filter Height (Default=3)"
+    )
+    parser.add_argument(
+        "-fil_w",
+        "-x",
+        default=3,
+        type=int,
+        help="Filter Width (Default=3)"
+    )
+    parser.add_argument(
+        "-conv_stride_d",
+        "-#",
+        default=1,
+        type=int,
+        help="Convolution Stride for Depth (Default=1)"
+    )
+    parser.add_argument(
+        "-conv_stride_h",
+        "-u",
+        default=1,
+        type=int,
+        help="Convolution Stride for Height (Default=1)"
+    )
+    parser.add_argument(
+        "-conv_stride_w",
+        "-v",
+        default=1,
+        type=int,
+        help="Convolution Stride for Width (Default=1)"
+    )
+    # Padding defaults to 0 as stated in the help texts, so that an omitted
+    # padding flag is not silently converted into a padded convolution.
+    parser.add_argument(
+        "-pad_d",
+        "-$",
+        default=0,
+        type=int,
+        help="Zero Padding for Depth (Default=0)"
+    )
+    parser.add_argument(
+        "-pad_h",
+        "-p",
+        default=0,
+        type=int,
+        help="Zero Padding for Height (Default=0)"
+    )
+    parser.add_argument(
+        "-pad_w",
+        "-q",
+        default=0,
+        type=int,
+        help="Zero Padding for Width (Default=0)"
+    )
+    parser.add_argument(
+        "-verify",
+        "-V",
+        default=1,
+        type=int,
+        help="Verify Each Layer (Default=1)"
+    )
+    parser.add_argument(
+        "-time",
+        "-t",
+        default=0,
+        type=int,
+        help="Time Each Layer (Default=0)"
+    )
+    parser.add_argument(
+        "-dilation_d",
+        "-^",
+        default=1,
+        type=int,
+        help="Dilation of Filter Depth (Default=1)"
+    )
+    parser.add_argument(
+        "-dilation_h",
+        "-l",
+        default=1,
+        type=int,
+        help="Dilation of Filter Height (Default=1)"
+    )
+    parser.add_argument(
+        "-dilation_w",
+        "-j",
+        default=1,
+        type=int,
+        help="Dilation of Filter Width (Default=1)"
+    )
+    parser.add_argument(
+        "-group_count",
+        "-g",
+        type=int,
+        default=1,
+        help="Number of Groups (Default=1)"
+    )
+
+    # Arguments that are not declared above (the driver path, the driver
+    # name and unsupported MIOpen flags) are collected in `unknown`.
+    args, unknown = parser.parse_known_args()
+    init_const_args(args)
+    process_miopen_driver_name(args, unknown)
+    print("Ignored args:")
+    print(unknown)
+    run_ck_profiler(args)