refactor blockgemm change, isolate to v2;

This commit is contained in:
aska-0096
2025-08-12 14:25:50 +00:00
parent fd1eb323af
commit 0810799e25
8 changed files with 514 additions and 194 deletions

View File

@@ -1038,7 +1038,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1096,7 +1096,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1190,7 +1190,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1249,7 +1249,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1344,7 +1344,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1379,7 +1379,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1490,7 +1490,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1589,7 +1589,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1623,7 +1623,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1667,7 +1667,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1718,7 +1718,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
pt_out.set_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
pt_warp_tensor.get_thread_buffer());
});
@@ -1768,7 +1768,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
dst_out.set_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
dst_warp_tensor.get_thread_buffer());
});

View File

@@ -5,13 +5,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
// can remove all bank conflicts, but drop the performance for some cases
// Probably it is limited by compiler optimization.
@@ -512,8 +510,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
GemmLoopOrder::MNK>;
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::QDataType,
@@ -525,13 +522,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
WarpGemm>;
WarpGemm,
GemmLoopOrder::MNK>;
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
@@ -546,8 +544,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
GemmLoopOrder::KMN>;
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::PDataType,
@@ -567,13 +564,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
: WGAttrNumAccessEnum::Single>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::PDataType,
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
WarpGemm,
GemmLoopOrder::KMN>;
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>