mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-25 09:37:42 +00:00
refactor blockgemm change, isolate to v2;
This commit is contained in:
@@ -1038,7 +1038,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1096,7 +1096,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1190,7 +1190,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1249,7 +1249,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1344,7 +1344,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1379,7 +1379,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1490,7 +1490,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1589,7 +1589,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1623,7 +1623,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1667,7 +1667,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -1718,7 +1718,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
pt_out.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
|
||||
pt_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
@@ -1768,7 +1768,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
dst_out.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
|
||||
dst_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
|
||||
@@ -5,13 +5,11 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
|
||||
|
||||
// can remove all bank conflicts, but drop the performance for some cases
|
||||
// Probably it is limited by compiler optimization.
|
||||
@@ -512,8 +510,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
|
||||
GemmLoopOrder::MNK>;
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::QDataType,
|
||||
@@ -525,13 +522,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
true>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
|
||||
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
WarpGemm>;
|
||||
WarpGemm,
|
||||
GemmLoopOrder::MNK>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -546,8 +544,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
|
||||
GemmLoopOrder::KMN>;
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::PDataType,
|
||||
@@ -567,13 +564,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
: WGAttrNumAccessEnum::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::PDataType,
|
||||
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
WarpGemm>;
|
||||
WarpGemm,
|
||||
GemmLoopOrder::KMN>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -42,8 +42,6 @@ struct BlockGemmARegBRegCRegV1
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr auto BlockGemmLoopOrder = Problem::BlockGemmLoopOrder;
|
||||
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
};
|
||||
|
||||
@@ -54,9 +52,8 @@ struct BlockGemmARegBRegCRegV1
|
||||
|
||||
using Traits = GemmTraits_<Problem, Policy>;
|
||||
|
||||
using WarpGemm = typename Traits::WarpGemm;
|
||||
using BlockGemmShape = typename Traits::BlockGemmShape;
|
||||
static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder;
|
||||
using WarpGemm = typename Traits::WarpGemm;
|
||||
using BlockGemmShape = typename Traits::BlockGemmShape;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
@@ -89,36 +86,17 @@ struct BlockGemmARegBRegCRegV1
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,33 +118,17 @@ struct BlockGemmARegBRegCRegV1
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,82 +213,40 @@ struct BlockGemmARegBRegCRegV1
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
using c_iter_idx = std::conditional_t<TransposeC,
|
||||
sequence<nIter, mIter>,
|
||||
sequence<mIter, nIter>>;
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
// read C warp tensor from C block tensor
|
||||
using c_iter_idx = std::
|
||||
conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
|
||||
372
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
Normal file
372
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
Normal file
@@ -0,0 +1,372 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This BlockGemm enhanced the control over inst issue order
|
||||
// A is block distributed tensor
|
||||
// B is block distributed tensor
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct BlockGemmARegBRegCRegV2
|
||||
{
|
||||
private:
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
struct GemmTraits_
|
||||
{
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr auto BlockGemmLoopOrder = Policy::BlockGemmLoopOrder;
|
||||
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
};
|
||||
|
||||
public:
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using Traits = GemmTraits_<Problem, Policy>;
|
||||
|
||||
using WarpGemm = typename Traits::WarpGemm;
|
||||
using BlockGemmShape = typename Traits::BlockGemmShape;
|
||||
static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
|
||||
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
|
||||
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
|
||||
{
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
return c_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
return c_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// check ABC-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"A distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"B distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"C distribution is wrong!");
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensor, typename BBlockTensor>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct GemmLoopOrder
|
||||
{
|
||||
KMN,
|
||||
MNK,
|
||||
};
|
||||
|
||||
template <typename AType_,
|
||||
typename BType_,
|
||||
typename CType_,
|
||||
typename BlockWarps_,
|
||||
typename WarpGemm_,
|
||||
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
|
||||
struct BlockGemmARegBRegCRegV2CustomPolicy
|
||||
{
|
||||
using AType = remove_cvref_t<AType_>;
|
||||
using BType = remove_cvref_t<BType_>;
|
||||
using CType = remove_cvref_t<CType_>;
|
||||
|
||||
using BlockWarps = remove_cvref_t<BlockWarps_>;
|
||||
|
||||
static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
|
||||
static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
|
||||
static constexpr index_t kKWarps = BlockWarps::at(number<2>{});
|
||||
|
||||
using WarpGemm = remove_cvref_t<WarpGemm_>;
|
||||
|
||||
static constexpr auto BlockGemmLoopOrder = BlockGemmLoopOrder_;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
return make_tuple(WarpGemm{}, kMWarps, kNWarps);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -4,7 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -14,8 +13,7 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_,
|
||||
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN,
|
||||
index_t NumWaveGroups_ = 1>
|
||||
index_t NumWaveGroups_ = 1>
|
||||
struct BlockGemmProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
@@ -23,9 +21,8 @@ struct BlockGemmProblem
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -39,12 +39,6 @@ enum struct TailNumber
|
||||
Full,
|
||||
};
|
||||
|
||||
enum struct GemmLoopOrder
|
||||
{
|
||||
KMN,
|
||||
MNK,
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s)
|
||||
|
||||
@@ -14,11 +14,10 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
struct GemmPipelineProblemBase
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
@@ -46,10 +45,9 @@ struct GemmPipelineProblemBase
|
||||
static constexpr bool kPadN = Traits::kPadN;
|
||||
static constexpr bool kPadK = Traits::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
|
||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
|
||||
// In the base situation, the Preshuffle setting should be false.
|
||||
static constexpr bool Preshuffle = false;
|
||||
@@ -169,11 +167,10 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
@@ -182,22 +179,20 @@ using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
||||
ComputeDataType_,
|
||||
FixedVectorSize_,
|
||||
VectorSizeA_,
|
||||
VectorSizeB_,
|
||||
BlockGemmLoopOrder_>;
|
||||
VectorSizeB_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
struct UniversalGemmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
@@ -229,9 +224,8 @@ struct UniversalGemmPipelineProblem
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr bool Preshuffle = Traits::Preshuffle;
|
||||
|
||||
static constexpr index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr index_t VectorSizeB = VectorSizeB_;
|
||||
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
|
||||
static constexpr index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr index_t VectorSizeB = VectorSizeB_;
|
||||
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
Reference in New Issue
Block a user