mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
refactor blockgemm change, isolate to v2;
This commit is contained in:
@@ -1038,7 +1038,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|||||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||||
tuple<sequence<0, 1>>,
|
tuple<sequence<0, 1>>,
|
||||||
tuple<sequence<0, 1>>,
|
tuple<sequence<0, 1>>,
|
||||||
sequence<2, 1>,
|
sequence<1, 2>,
|
||||||
sequence<0, 0>>{};
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
@@ -1096,7 +1096,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|||||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||||
tuple<sequence<0, 1>>,
|
tuple<sequence<0, 1>>,
|
||||||
tuple<sequence<0, 1>>,
|
tuple<sequence<0, 1>>,
|
||||||
sequence<2, 1>,
|
sequence<1, 2>,
|
||||||
sequence<0, 0>>{};
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
@@ -1190,7 +1190,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|||||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||||
tuple<sequence<0, 1>>,
|
tuple<sequence<0, 1>>,
|
||||||
tuple<sequence<0, 1>>,
|
tuple<sequence<0, 1>>,
|
||||||
sequence<2, 1>,
|
sequence<1, 2>,
|
||||||
sequence<0, 0>>{};
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
@@ -1249,7 +1249,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|||||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||||
tuple<sequence<1, 0>>,
|
tuple<sequence<1, 0>>,
|
||||||
tuple<sequence<1, 0>>,
|
tuple<sequence<1, 0>>,
|
||||||
sequence<2, 1>,
|
sequence<1, 2>,
|
||||||
sequence<0, 0>>{};
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
@@ -1344,7 +1344,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|||||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||||
tuple<sequence<0, 1>>,
|
tuple<sequence<0, 1>>,
|
||||||
tuple<sequence<0, 1>>,
|
tuple<sequence<0, 1>>,
|
||||||
sequence<2, 1>,
|
sequence<1, 2>,
|
||||||
sequence<0, 0>>{};
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
@@ -1379,7 +1379,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|||||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||||
tuple<sequence<1, 0>>,
|
tuple<sequence<1, 0>>,
|
||||||
tuple<sequence<1, 0>>,
|
tuple<sequence<1, 0>>,
|
||||||
sequence<2, 1>,
|
sequence<1, 2>,
|
||||||
sequence<0, 0>>{};
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
@@ -1490,7 +1490,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|||||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||||
tuple<sequence<1, 0>>,
|
tuple<sequence<1, 0>>,
|
||||||
tuple<sequence<1, 0>>,
|
tuple<sequence<1, 0>>,
|
||||||
sequence<2, 1>,
|
sequence<1, 2>,
|
||||||
sequence<0, 0>>{};
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
@@ -1589,7 +1589,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|||||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||||
tuple<sequence<0, 1>>,
|
tuple<sequence<0, 1>>,
|
||||||
tuple<sequence<0, 1>>,
|
tuple<sequence<0, 1>>,
|
||||||
sequence<2, 1>,
|
sequence<1, 2>,
|
||||||
sequence<0, 0>>{};
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
@@ -1623,7 +1623,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|||||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||||
tuple<sequence<1, 0>>,
|
tuple<sequence<1, 0>>,
|
||||||
tuple<sequence<1, 0>>,
|
tuple<sequence<1, 0>>,
|
||||||
sequence<2, 1>,
|
sequence<1, 2>,
|
||||||
sequence<0, 0>>{};
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
@@ -1667,7 +1667,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|||||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||||
tuple<sequence<1, 0>>,
|
tuple<sequence<1, 0>>,
|
||||||
tuple<sequence<1, 0>>,
|
tuple<sequence<1, 0>>,
|
||||||
sequence<2, 1>,
|
sequence<1, 2>,
|
||||||
sequence<0, 0>>{};
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
@@ -1718,7 +1718,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|||||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||||
|
|
||||||
pt_out.set_y_sliced_thread_data(
|
pt_out.set_y_sliced_thread_data(
|
||||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
|
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
|
||||||
pt_warp_tensor.get_thread_buffer());
|
pt_warp_tensor.get_thread_buffer());
|
||||||
});
|
});
|
||||||
@@ -1768,7 +1768,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
|||||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||||
|
|
||||||
dst_out.set_y_sliced_thread_data(
|
dst_out.set_y_sliced_thread_data(
|
||||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
|
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
|
||||||
dst_warp_tensor.get_thread_buffer());
|
dst_warp_tensor.get_thread_buffer());
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -5,13 +5,11 @@
|
|||||||
|
|
||||||
#include "ck_tile/core.hpp"
|
#include "ck_tile/core.hpp"
|
||||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
|
||||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
|
||||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
|
||||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
|
||||||
|
|
||||||
// can remove all bank conflicts, but drop the performance for some cases
|
// can remove all bank conflicts, but drop the performance for some cases
|
||||||
// Probably it is limited by compiler optimization.
|
// Probably it is limited by compiler optimization.
|
||||||
@@ -512,8 +510,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
|||||||
Problem::BlockFmhaShape::kN0,
|
Problem::BlockFmhaShape::kN0,
|
||||||
Problem::BlockFmhaShape::kK0>,
|
Problem::BlockFmhaShape::kK0>,
|
||||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
|
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||||
GemmLoopOrder::MNK>;
|
|
||||||
|
|
||||||
using WarpGemm =
|
using WarpGemm =
|
||||||
WarpGemmMfmaDispatcher<typename Problem::QDataType,
|
WarpGemmMfmaDispatcher<typename Problem::QDataType,
|
||||||
@@ -525,13 +522,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
|||||||
true>;
|
true>;
|
||||||
|
|
||||||
using BlockGemmPolicy =
|
using BlockGemmPolicy =
|
||||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
|
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
|
||||||
typename Problem::KDataType,
|
typename Problem::KDataType,
|
||||||
typename Problem::SaccDataType,
|
typename Problem::SaccDataType,
|
||||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||||
WarpGemm>;
|
WarpGemm,
|
||||||
|
GemmLoopOrder::MNK>;
|
||||||
|
|
||||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Problem>
|
template <typename Problem>
|
||||||
@@ -546,8 +544,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
|||||||
Problem::BlockFmhaShape::kN1,
|
Problem::BlockFmhaShape::kN1,
|
||||||
Problem::BlockFmhaShape::kK1>,
|
Problem::BlockFmhaShape::kK1>,
|
||||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
|
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||||
GemmLoopOrder::KMN>;
|
|
||||||
|
|
||||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||||
typename Problem::PDataType,
|
typename Problem::PDataType,
|
||||||
@@ -567,13 +564,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
|||||||
: WGAttrNumAccessEnum::Single>;
|
: WGAttrNumAccessEnum::Single>;
|
||||||
|
|
||||||
using BlockGemmPolicy =
|
using BlockGemmPolicy =
|
||||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::PDataType,
|
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
|
||||||
typename Problem::VDataType,
|
typename Problem::VDataType,
|
||||||
typename Problem::OaccDataType,
|
typename Problem::OaccDataType,
|
||||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||||
WarpGemm>;
|
WarpGemm,
|
||||||
|
GemmLoopOrder::KMN>;
|
||||||
|
|
||||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Problem>
|
template <typename Problem>
|
||||||
|
|||||||
@@ -42,8 +42,6 @@ struct BlockGemmARegBRegCRegV1
|
|||||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||||
|
|
||||||
static constexpr auto BlockGemmLoopOrder = Problem::BlockGemmLoopOrder;
|
|
||||||
|
|
||||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -54,9 +52,8 @@ struct BlockGemmARegBRegCRegV1
|
|||||||
|
|
||||||
using Traits = GemmTraits_<Problem, Policy>;
|
using Traits = GemmTraits_<Problem, Policy>;
|
||||||
|
|
||||||
using WarpGemm = typename Traits::WarpGemm;
|
using WarpGemm = typename Traits::WarpGemm;
|
||||||
using BlockGemmShape = typename Traits::BlockGemmShape;
|
using BlockGemmShape = typename Traits::BlockGemmShape;
|
||||||
static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder;
|
|
||||||
|
|
||||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||||
@@ -89,36 +86,17 @@ struct BlockGemmARegBRegCRegV1
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||||
{
|
sequence<NWarp>,
|
||||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||||
sequence<NWarp>,
|
tuple<sequence<1, 0>>,
|
||||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
tuple<sequence<1, 0>>,
|
||||||
tuple<sequence<1, 0>>,
|
sequence<1, 2>,
|
||||||
tuple<sequence<1, 0>>,
|
sequence<0, 0>>{};
|
||||||
sequence<2, 1>,
|
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
sequence<0, 0>>{};
|
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||||
|
|
||||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
return a_block_dstr_encode;
|
||||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
|
||||||
|
|
||||||
return a_block_dstr_encode;
|
|
||||||
}
|
|
||||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
|
||||||
{
|
|
||||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
|
||||||
sequence<NWarp>,
|
|
||||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
|
||||||
tuple<sequence<1, 0>>,
|
|
||||||
tuple<sequence<1, 0>>,
|
|
||||||
sequence<1, 2>,
|
|
||||||
sequence<0, 0>>{};
|
|
||||||
|
|
||||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
|
||||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
|
||||||
|
|
||||||
return a_block_dstr_encode;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,33 +118,17 @@ struct BlockGemmARegBRegCRegV1
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||||
{
|
sequence<MWarp>,
|
||||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||||
sequence<MWarp>,
|
tuple<sequence<0, 1>>,
|
||||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
tuple<sequence<0, 1>>,
|
||||||
tuple<sequence<0, 1>>,
|
sequence<1, 2>,
|
||||||
tuple<sequence<0, 1>>,
|
sequence<0, 0>>{};
|
||||||
sequence<2, 1>,
|
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
sequence<0, 0>>{};
|
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
|
||||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
|
||||||
|
|
||||||
return b_block_dstr_encode;
|
return b_block_dstr_encode;
|
||||||
}
|
|
||||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
|
||||||
{
|
|
||||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
|
||||||
sequence<MWarp>,
|
|
||||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
|
||||||
tuple<sequence<0, 1>>,
|
|
||||||
tuple<sequence<0, 1>>,
|
|
||||||
sequence<1, 2>,
|
|
||||||
sequence<0, 0>>{};
|
|
||||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
|
||||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
|
||||||
return b_block_dstr_encode;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,82 +213,40 @@ struct BlockGemmARegBRegCRegV1
|
|||||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||||
|
|
||||||
// hot loop:
|
// hot loop:
|
||||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||||
{
|
|
||||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
|
||||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
|
||||||
// read A warp tensor from A Block window
|
|
||||||
AWarpTensor a_warp_tensor;
|
|
||||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
|
||||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
|
||||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
|
||||||
|
|
||||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
|
||||||
// read B warp tensor from B block tensor
|
|
||||||
BWarpTensor b_warp_tensor;
|
|
||||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
|
||||||
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
|
|
||||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
|
||||||
|
|
||||||
// read C warp tensor from C block tensor
|
|
||||||
using c_iter_idx = std::conditional_t<TransposeC,
|
|
||||||
sequence<nIter, mIter>,
|
|
||||||
sequence<mIter, nIter>>;
|
|
||||||
CWarpTensor c_warp_tensor;
|
|
||||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
|
||||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
|
||||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
|
||||||
|
|
||||||
// warp GEMM
|
|
||||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
|
||||||
|
|
||||||
// write C warp tensor into C block tensor
|
|
||||||
c_block_tensor.set_y_sliced_thread_data(
|
|
||||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
|
||||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
|
||||||
c_warp_tensor.get_thread_buffer());
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
|
||||||
{
|
|
||||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||||
|
// read A warp tensor from A Block window
|
||||||
|
AWarpTensor a_warp_tensor;
|
||||||
|
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||||
|
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||||
|
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||||
|
|
||||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
// read B warp tensor from B block tensor
|
||||||
// read A warp tensor from A Block window
|
BWarpTensor b_warp_tensor;
|
||||||
AWarpTensor a_warp_tensor;
|
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||||
|
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||||
|
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||||
|
|
||||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
// read C warp tensor from C block tensor
|
||||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
using c_iter_idx = std::
|
||||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
|
||||||
|
CWarpTensor c_warp_tensor;
|
||||||
|
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||||
|
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||||
|
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||||
|
|
||||||
// read B warp tensor from B block tensor
|
// warp GEMM
|
||||||
BWarpTensor b_warp_tensor;
|
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||||
|
|
||||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
// write C warp tensor into C block tensor
|
||||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
c_block_tensor.set_y_sliced_thread_data(
|
||||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||||
|
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||||
// read C warp tensor from C block tensor
|
c_warp_tensor.get_thread_buffer());
|
||||||
CWarpTensor c_warp_tensor;
|
|
||||||
|
|
||||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
|
||||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
|
||||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
|
||||||
|
|
||||||
// warp GEMM
|
|
||||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
|
||||||
|
|
||||||
// write C warp tensor into C block tensor
|
|
||||||
c_block_tensor.set_y_sliced_thread_data(
|
|
||||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
|
||||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
|
||||||
c_warp_tensor.get_thread_buffer());
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||||
|
|||||||
372
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
Normal file
372
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
Normal file
@@ -0,0 +1,372 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ck_tile/core.hpp"
|
||||||
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
|
||||||
|
|
||||||
|
namespace ck_tile {
|
||||||
|
|
||||||
|
// This BlockGemm enhanced the control over inst issue order
|
||||||
|
// A is block distributed tensor
|
||||||
|
// B is block distributed tensor
|
||||||
|
// C is block distributed tensor
|
||||||
|
template <typename Problem_, typename Policy_>
|
||||||
|
struct BlockGemmARegBRegCRegV2
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||||
|
struct GemmTraits_
|
||||||
|
{
|
||||||
|
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||||
|
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||||
|
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||||
|
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||||
|
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||||
|
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||||
|
|
||||||
|
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||||
|
|
||||||
|
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||||
|
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||||
|
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||||
|
|
||||||
|
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||||
|
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||||
|
|
||||||
|
static constexpr index_t MWarp = config.template at<1>();
|
||||||
|
static constexpr index_t NWarp = config.template at<2>();
|
||||||
|
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||||
|
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||||
|
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||||
|
|
||||||
|
static constexpr auto BlockGemmLoopOrder = Policy::BlockGemmLoopOrder;
|
||||||
|
|
||||||
|
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||||
|
};
|
||||||
|
|
||||||
|
public:
|
||||||
|
using Problem = remove_cvref_t<Problem_>;
|
||||||
|
using Policy = remove_cvref_t<Policy_>;
|
||||||
|
|
||||||
|
using Traits = GemmTraits_<Problem, Policy>;
|
||||||
|
|
||||||
|
using WarpGemm = typename Traits::WarpGemm;
|
||||||
|
using BlockGemmShape = typename Traits::BlockGemmShape;
|
||||||
|
static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder;
|
||||||
|
|
||||||
|
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||||
|
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||||
|
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||||
|
|
||||||
|
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||||
|
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
|
||||||
|
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
|
||||||
|
|
||||||
|
static constexpr index_t MWarp = Traits::MWarp;
|
||||||
|
static constexpr index_t NWarp = Traits::NWarp;
|
||||||
|
static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
|
||||||
|
|
||||||
|
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||||
|
{
|
||||||
|
if constexpr(UseDefaultScheduler)
|
||||||
|
{
|
||||||
|
constexpr auto a_block_outer_dstr_encoding =
|
||||||
|
tile_distribution_encoding<sequence<NWarp>,
|
||||||
|
tuple<sequence<MIterPerWarp>, sequence<KIterPerWarp>>,
|
||||||
|
tuple<>,
|
||||||
|
tuple<>,
|
||||||
|
sequence<1, 2>,
|
||||||
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
|
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
|
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||||
|
|
||||||
|
return a_block_dstr_encode;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||||
|
{
|
||||||
|
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||||
|
sequence<NWarp>,
|
||||||
|
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||||
|
tuple<sequence<1, 0>>,
|
||||||
|
tuple<sequence<1, 0>>,
|
||||||
|
sequence<2, 1>,
|
||||||
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
|
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
|
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||||
|
|
||||||
|
return a_block_dstr_encode;
|
||||||
|
}
|
||||||
|
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||||
|
{
|
||||||
|
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||||
|
sequence<NWarp>,
|
||||||
|
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||||
|
tuple<sequence<1, 0>>,
|
||||||
|
tuple<sequence<1, 0>>,
|
||||||
|
sequence<1, 2>,
|
||||||
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
|
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
|
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||||
|
|
||||||
|
return a_block_dstr_encode;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||||
|
{
|
||||||
|
if constexpr(UseDefaultScheduler)
|
||||||
|
{
|
||||||
|
constexpr auto b_block_outer_dstr_encoding =
|
||||||
|
tile_distribution_encoding<sequence<MWarp>,
|
||||||
|
tuple<sequence<NIterPerWarp>, sequence<KIterPerWarp>>,
|
||||||
|
tuple<>,
|
||||||
|
tuple<>,
|
||||||
|
sequence<1, 2>,
|
||||||
|
sequence<0, 0>>{};
|
||||||
|
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
|
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||||
|
|
||||||
|
return b_block_dstr_encode;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||||
|
{
|
||||||
|
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||||
|
sequence<MWarp>,
|
||||||
|
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||||
|
tuple<sequence<0, 1>>,
|
||||||
|
tuple<sequence<0, 1>>,
|
||||||
|
sequence<2, 1>,
|
||||||
|
sequence<0, 0>>{};
|
||||||
|
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
|
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||||
|
|
||||||
|
return b_block_dstr_encode;
|
||||||
|
}
|
||||||
|
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||||
|
{
|
||||||
|
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||||
|
sequence<MWarp>,
|
||||||
|
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||||
|
tuple<sequence<0, 1>>,
|
||||||
|
tuple<sequence<0, 1>>,
|
||||||
|
sequence<1, 2>,
|
||||||
|
sequence<0, 0>>{};
|
||||||
|
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
|
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||||
|
return b_block_dstr_encode;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
|
||||||
|
{
|
||||||
|
if constexpr(UseDefaultScheduler)
|
||||||
|
{
|
||||||
|
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||||
|
sequence<MWarp>,
|
||||||
|
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||||
|
tuple<>,
|
||||||
|
tuple<>,
|
||||||
|
sequence<1, 2>,
|
||||||
|
sequence<0, 0>>{};
|
||||||
|
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
|
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||||
|
|
||||||
|
return c_block_dstr_encode;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||||
|
sequence<>,
|
||||||
|
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||||
|
tuple<sequence<1, 2>>,
|
||||||
|
tuple<sequence<1, 1>>,
|
||||||
|
sequence<1, 2>,
|
||||||
|
sequence<0, 0>>{};
|
||||||
|
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
|
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||||
|
|
||||||
|
return c_block_dstr_encode;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// C += A * B
|
||||||
|
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
|
||||||
|
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||||
|
const ABlockTensor& a_block_tensor,
|
||||||
|
const BBlockTensor& b_block_tensor) const
|
||||||
|
{
|
||||||
|
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||||
|
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
|
||||||
|
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||||
|
"wrong!");
|
||||||
|
|
||||||
|
// check ABC-block-distribution
|
||||||
|
static_assert(
|
||||||
|
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
|
||||||
|
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
|
||||||
|
.get_static_tile_distribution_encoding())>>,
|
||||||
|
"A distribution is wrong!");
|
||||||
|
static_assert(
|
||||||
|
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
|
||||||
|
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
|
||||||
|
.get_static_tile_distribution_encoding())>>,
|
||||||
|
"B distribution is wrong!");
|
||||||
|
static_assert(
|
||||||
|
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
|
||||||
|
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||||
|
.get_static_tile_distribution_encoding())>>,
|
||||||
|
"C distribution is wrong!");
|
||||||
|
|
||||||
|
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||||
|
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||||
|
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||||
|
|
||||||
|
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||||
|
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||||
|
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||||
|
|
||||||
|
constexpr auto a_warp_y_lengths =
|
||||||
|
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||||
|
constexpr auto b_warp_y_lengths =
|
||||||
|
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||||
|
constexpr auto c_warp_y_lengths =
|
||||||
|
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||||
|
|
||||||
|
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||||
|
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||||
|
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||||
|
|
||||||
|
// hot loop:
|
||||||
|
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||||
|
{
|
||||||
|
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||||
|
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||||
|
// read A warp tensor from A Block window
|
||||||
|
AWarpTensor a_warp_tensor;
|
||||||
|
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||||
|
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
||||||
|
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||||
|
|
||||||
|
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||||
|
// read B warp tensor from B block tensor
|
||||||
|
BWarpTensor b_warp_tensor;
|
||||||
|
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||||
|
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
|
||||||
|
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||||
|
|
||||||
|
CWarpTensor c_warp_tensor;
|
||||||
|
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||||
|
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||||
|
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||||
|
|
||||||
|
// warp GEMM
|
||||||
|
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||||
|
|
||||||
|
// write C warp tensor into C block tensor
|
||||||
|
c_block_tensor.set_y_sliced_thread_data(
|
||||||
|
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||||
|
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||||
|
c_warp_tensor.get_thread_buffer());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||||
|
{
|
||||||
|
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||||
|
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||||
|
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||||
|
// read A warp tensor from A Block window
|
||||||
|
AWarpTensor a_warp_tensor;
|
||||||
|
|
||||||
|
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||||
|
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||||
|
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||||
|
|
||||||
|
// read B warp tensor from B block tensor
|
||||||
|
BWarpTensor b_warp_tensor;
|
||||||
|
|
||||||
|
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||||
|
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||||
|
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||||
|
|
||||||
|
// read C warp tensor from C block tensor
|
||||||
|
CWarpTensor c_warp_tensor;
|
||||||
|
|
||||||
|
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||||
|
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||||
|
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||||
|
|
||||||
|
// warp GEMM
|
||||||
|
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||||
|
|
||||||
|
// write C warp tensor into C block tensor
|
||||||
|
c_block_tensor.set_y_sliced_thread_data(
|
||||||
|
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||||
|
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||||
|
c_warp_tensor.get_thread_buffer());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||||
|
{
|
||||||
|
if constexpr(UseDefaultScheduler)
|
||||||
|
{
|
||||||
|
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||||
|
sequence<MWarp>,
|
||||||
|
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||||
|
tuple<>,
|
||||||
|
tuple<>,
|
||||||
|
sequence<1, 2>,
|
||||||
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
|
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
|
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||||
|
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||||
|
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||||
|
return c_block_tensor;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||||
|
sequence<>,
|
||||||
|
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||||
|
tuple<sequence<1, 2>>,
|
||||||
|
tuple<sequence<1, 1>>,
|
||||||
|
sequence<1, 2>,
|
||||||
|
sequence<0, 0>>{};
|
||||||
|
|
||||||
|
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||||
|
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||||
|
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||||
|
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||||
|
return c_block_tensor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// C = A * B
|
||||||
|
template <typename ABlockTensor, typename BBlockTensor>
|
||||||
|
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
|
||||||
|
const BBlockTensor& b_block_tensor) const
|
||||||
|
{
|
||||||
|
auto c_block_tensor = MakeCBlockTile();
|
||||||
|
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
|
||||||
|
return c_block_tensor;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ck_tile
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ck_tile/core.hpp"
|
||||||
|
|
||||||
|
namespace ck_tile {
|
||||||
|
|
||||||
|
enum struct GemmLoopOrder
|
||||||
|
{
|
||||||
|
KMN,
|
||||||
|
MNK,
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename AType_,
|
||||||
|
typename BType_,
|
||||||
|
typename CType_,
|
||||||
|
typename BlockWarps_,
|
||||||
|
typename WarpGemm_,
|
||||||
|
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
|
||||||
|
struct BlockGemmARegBRegCRegV2CustomPolicy
|
||||||
|
{
|
||||||
|
using AType = remove_cvref_t<AType_>;
|
||||||
|
using BType = remove_cvref_t<BType_>;
|
||||||
|
using CType = remove_cvref_t<CType_>;
|
||||||
|
|
||||||
|
using BlockWarps = remove_cvref_t<BlockWarps_>;
|
||||||
|
|
||||||
|
static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
|
||||||
|
static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
|
||||||
|
static constexpr index_t kKWarps = BlockWarps::at(number<2>{});
|
||||||
|
|
||||||
|
using WarpGemm = remove_cvref_t<WarpGemm_>;
|
||||||
|
|
||||||
|
static constexpr auto BlockGemmLoopOrder = BlockGemmLoopOrder_;
|
||||||
|
|
||||||
|
template <typename Problem>
|
||||||
|
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||||
|
{
|
||||||
|
return make_tuple(WarpGemm{}, kMWarps, kNWarps);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ck_tile
|
||||||
@@ -4,7 +4,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "ck_tile/core.hpp"
|
#include "ck_tile/core.hpp"
|
||||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
|
||||||
|
|
||||||
namespace ck_tile {
|
namespace ck_tile {
|
||||||
|
|
||||||
@@ -14,8 +13,7 @@ template <typename ADataType_,
|
|||||||
typename CDataType_,
|
typename CDataType_,
|
||||||
index_t kBlockSize_,
|
index_t kBlockSize_,
|
||||||
typename BlockGemmShape_,
|
typename BlockGemmShape_,
|
||||||
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN,
|
index_t NumWaveGroups_ = 1>
|
||||||
index_t NumWaveGroups_ = 1>
|
|
||||||
struct BlockGemmProblem
|
struct BlockGemmProblem
|
||||||
{
|
{
|
||||||
using ADataType = remove_cvref_t<ADataType_>;
|
using ADataType = remove_cvref_t<ADataType_>;
|
||||||
@@ -23,9 +21,8 @@ struct BlockGemmProblem
|
|||||||
using CDataType = remove_cvref_t<CDataType_>;
|
using CDataType = remove_cvref_t<CDataType_>;
|
||||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||||
|
|
||||||
static constexpr index_t kBlockSize = kBlockSize_;
|
static constexpr index_t kBlockSize = kBlockSize_;
|
||||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||||
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ck_tile
|
} // namespace ck_tile
|
||||||
|
|||||||
@@ -39,12 +39,6 @@ enum struct TailNumber
|
|||||||
Full,
|
Full,
|
||||||
};
|
};
|
||||||
|
|
||||||
enum struct GemmLoopOrder
|
|
||||||
{
|
|
||||||
KMN,
|
|
||||||
MNK,
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace ck_tile
|
} // namespace ck_tile
|
||||||
|
|
||||||
inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s)
|
inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s)
|
||||||
|
|||||||
@@ -14,11 +14,10 @@ template <typename ADataType_,
|
|||||||
typename CDataType_,
|
typename CDataType_,
|
||||||
typename BlockGemmShape_,
|
typename BlockGemmShape_,
|
||||||
typename Traits_,
|
typename Traits_,
|
||||||
typename ComputeDataType_ = ADataType_,
|
typename ComputeDataType_ = ADataType_,
|
||||||
bool FixedVectorSize_ = false,
|
bool FixedVectorSize_ = false,
|
||||||
index_t VectorSizeA_ = 1,
|
index_t VectorSizeA_ = 1,
|
||||||
index_t VectorSizeB_ = 1,
|
index_t VectorSizeB_ = 1>
|
||||||
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
|
|
||||||
struct GemmPipelineProblemBase
|
struct GemmPipelineProblemBase
|
||||||
{
|
{
|
||||||
using Traits = remove_cvref_t<Traits_>;
|
using Traits = remove_cvref_t<Traits_>;
|
||||||
@@ -46,10 +45,9 @@ struct GemmPipelineProblemBase
|
|||||||
static constexpr bool kPadN = Traits::kPadN;
|
static constexpr bool kPadN = Traits::kPadN;
|
||||||
static constexpr bool kPadK = Traits::kPadK;
|
static constexpr bool kPadK = Traits::kPadK;
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||||
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
|
|
||||||
|
|
||||||
// In the base situation, the Preshuffle setting should be false.
|
// In the base situation, the Preshuffle setting should be false.
|
||||||
static constexpr bool Preshuffle = false;
|
static constexpr bool Preshuffle = false;
|
||||||
@@ -169,11 +167,10 @@ template <typename ADataType_,
|
|||||||
typename CDataType_,
|
typename CDataType_,
|
||||||
typename BlockGemmShape_,
|
typename BlockGemmShape_,
|
||||||
typename Traits_,
|
typename Traits_,
|
||||||
typename ComputeDataType_ = ADataType_,
|
typename ComputeDataType_ = ADataType_,
|
||||||
bool FixedVectorSize_ = false,
|
bool FixedVectorSize_ = false,
|
||||||
index_t VectorSizeA_ = 1,
|
index_t VectorSizeA_ = 1,
|
||||||
index_t VectorSizeB_ = 1,
|
index_t VectorSizeB_ = 1>
|
||||||
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
|
|
||||||
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
||||||
BDataType_,
|
BDataType_,
|
||||||
CDataType_,
|
CDataType_,
|
||||||
@@ -182,22 +179,20 @@ using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
|||||||
ComputeDataType_,
|
ComputeDataType_,
|
||||||
FixedVectorSize_,
|
FixedVectorSize_,
|
||||||
VectorSizeA_,
|
VectorSizeA_,
|
||||||
VectorSizeB_,
|
VectorSizeB_>;
|
||||||
BlockGemmLoopOrder_>;
|
|
||||||
|
|
||||||
template <typename ADataType_,
|
template <typename ADataType_,
|
||||||
typename BDataType_,
|
typename BDataType_,
|
||||||
typename CDataType_,
|
typename CDataType_,
|
||||||
typename BlockGemmShape_,
|
typename BlockGemmShape_,
|
||||||
typename Traits_,
|
typename Traits_,
|
||||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||||
bool HasHotLoop_ = true,
|
bool HasHotLoop_ = true,
|
||||||
TailNumber TailNum_ = TailNumber::Full,
|
TailNumber TailNum_ = TailNumber::Full,
|
||||||
typename ComputeDataType_ = ADataType_,
|
typename ComputeDataType_ = ADataType_,
|
||||||
bool FixedVectorSize_ = false,
|
bool FixedVectorSize_ = false,
|
||||||
index_t VectorSizeA_ = 1,
|
index_t VectorSizeA_ = 1,
|
||||||
index_t VectorSizeB_ = 1,
|
index_t VectorSizeB_ = 1>
|
||||||
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
|
|
||||||
struct UniversalGemmPipelineProblem
|
struct UniversalGemmPipelineProblem
|
||||||
{
|
{
|
||||||
using Traits = remove_cvref_t<Traits_>;
|
using Traits = remove_cvref_t<Traits_>;
|
||||||
@@ -229,9 +224,8 @@ struct UniversalGemmPipelineProblem
|
|||||||
static constexpr auto Scheduler = Scheduler_;
|
static constexpr auto Scheduler = Scheduler_;
|
||||||
static constexpr bool Preshuffle = Traits::Preshuffle;
|
static constexpr bool Preshuffle = Traits::Preshuffle;
|
||||||
|
|
||||||
static constexpr index_t VectorSizeA = VectorSizeA_;
|
static constexpr index_t VectorSizeA = VectorSizeA_;
|
||||||
static constexpr index_t VectorSizeB = VectorSizeB_;
|
static constexpr index_t VectorSizeB = VectorSizeB_;
|
||||||
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
|
|
||||||
|
|
||||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||||
static constexpr auto TailNum = TailNum_;
|
static constexpr auto TailNum = TailNum_;
|
||||||
|
|||||||
Reference in New Issue
Block a user