mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
refactor blockgemm change, isolate to v2;
This commit is contained in:
@@ -1038,7 +1038,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1096,7 +1096,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1190,7 +1190,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1249,7 +1249,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1344,7 +1344,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1379,7 +1379,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1490,7 +1490,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1589,7 +1589,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1623,7 +1623,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1667,7 +1667,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1718,7 +1718,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
pt_out.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
|
||||
pt_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
@@ -1768,7 +1768,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
dst_out.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
|
||||
dst_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
|
||||
@@ -5,13 +5,11 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
|
||||
|
||||
// can remove all bank conflicts, but drop the performance for some cases
|
||||
// Probably it is limited by compiler optimization.
|
||||
@@ -512,8 +510,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
|
||||
GemmLoopOrder::MNK>;
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::QDataType,
|
||||
@@ -525,13 +522,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
true>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
|
||||
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
WarpGemm>;
|
||||
WarpGemm,
|
||||
GemmLoopOrder::MNK>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -546,8 +544,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
|
||||
GemmLoopOrder::KMN>;
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::PDataType,
|
||||
@@ -567,13 +564,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
: WGAttrNumAccessEnum::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::PDataType,
|
||||
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
WarpGemm>;
|
||||
WarpGemm,
|
||||
GemmLoopOrder::KMN>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user