mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK TILE] Add gemm compute pipeline v3 (#1661)
* [CK TILE] Add gemm compute pipeline v3 * Enable universal gemm compute pipeline. * Rename example and add compute pipeline. * Introduce ag bg cr pipeline impl base. * Refactor to reuse code. * Cleaning * Formatting. --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
This commit is contained in:
@@ -41,13 +41,16 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
static_assert(MWarp == BlockGemmShape::BlockWarps::at(number<0>{}),
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
|
||||
"Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
|
||||
static_assert(NWarp == BlockGemmShape::BlockWarps::at(number<1>{}),
|
||||
static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
|
||||
"Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(number<0>{}),
|
||||
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
|
||||
"Error! WarpGemm's M is not consisten with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
|
||||
"Error! WarpGemm's N is not consisten with BlockGemmShape!");
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
@@ -99,6 +102,9 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
private:
|
||||
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
|
||||
struct BlockGemmImpl
|
||||
@@ -114,35 +120,31 @@ struct BlockUniversalGemmAsBsCr
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
static_assert(std::is_same_v<typename GemmTraits::ADataType,
|
||||
typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<typename GemmTraits::BDataType,
|
||||
typename BSmemBlockWindow::DataType>,
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
static_assert(
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}],
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
|
||||
"MPerBlock, NPerBlock, KPerBlock defined in "
|
||||
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp);
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);
|
||||
|
||||
// TODO: refactor warp_window tile type to class member as it should be
|
||||
// compile-time known information.
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() +
|
||||
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0},
|
||||
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
|
||||
|
||||
@@ -156,16 +158,15 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
|
||||
GemmTraits::MIterPerWarp>
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() +
|
||||
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0},
|
||||
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
|
||||
|
||||
@@ -179,10 +180,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
|
||||
GemmTraits::NIterPerWarp>
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
@@ -193,7 +194,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
@@ -203,8 +204,8 @@ struct BlockUniversalGemmAsBsCr
|
||||
});
|
||||
});
|
||||
|
||||
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
@@ -212,10 +213,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
// hot loop:
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// read C warp tensor from C block tensor-
|
||||
@@ -226,7 +227,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
typename GemmTraits::WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile);
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
@@ -243,13 +244,13 @@ struct BlockUniversalGemmAsBsCr
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
|
||||
{
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::AWarpTile, GemmTraits::KIterPerWarp>,
|
||||
GemmTraits::MIterPerWarp>
|
||||
statically_indexed_array<typename GemmTraits::AWarpTile, KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_tiles_;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::BWarpTile, GemmTraits::KIterPerWarp>,
|
||||
GemmTraits::NIterPerWarp>
|
||||
statically_indexed_array<typename GemmTraits::BWarpTile, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tiles_;
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
@@ -257,30 +258,27 @@ struct BlockUniversalGemmAsBsCr
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}],
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
|
||||
"MPerBlock, NPerBlock, KPerBlock defined in "
|
||||
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
|
||||
|
||||
static_assert(std::is_same_v<typename GemmTraits::ADataType,
|
||||
typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<typename GemmTraits::BDataType,
|
||||
typename BSmemBlockWindow::DataType>,
|
||||
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp);
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);
|
||||
|
||||
// TODO: refactor warp_window tile type to class member as it should be
|
||||
// compile-time known information.
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() +
|
||||
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0},
|
||||
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
|
||||
|
||||
@@ -292,18 +290,16 @@ struct BlockUniversalGemmAsBsCr
|
||||
AWarpWindow{}.get_window_lengths(),
|
||||
"AWarpWindow lengths must be equal to AWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
|
||||
GemmTraits::MIterPerWarp>
|
||||
statically_indexed_array<statically_indexed_array<AWarpWindow, KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() +
|
||||
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0},
|
||||
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
|
||||
|
||||
@@ -315,13 +311,12 @@ struct BlockUniversalGemmAsBsCr
|
||||
BWarpWindow{}.get_window_lengths(),
|
||||
"BWarpWindow lengths must be equal to BWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
|
||||
GemmTraits::NIterPerWarp>
|
||||
statically_indexed_array<statically_indexed_array<BWarpWindow, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
// TODO: I don't have to move 0,0 window!
|
||||
@@ -331,8 +326,8 @@ struct BlockUniversalGemmAsBsCr
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
@@ -341,12 +336,12 @@ struct BlockUniversalGemmAsBsCr
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
|
||||
});
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
|
||||
});
|
||||
@@ -359,22 +354,21 @@ struct BlockUniversalGemmAsBsCr
|
||||
[[maybe_unused]] const ASmemBlockWindow& a_block_window,
|
||||
[[maybe_unused]] const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
|
||||
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
@@ -383,9 +377,9 @@ struct BlockUniversalGemmAsBsCr
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
typename GemmTraits::WarpGemm{}(c_warp_tensor,
|
||||
a_warp_tiles_[mIter][kIter],
|
||||
b_warp_tiles_[nIter][kIter]);
|
||||
WarpGemm{}(c_warp_tensor,
|
||||
a_warp_tiles_[mIter][kIter],
|
||||
b_warp_tiles_[nIter][kIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
@@ -412,12 +406,12 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::AWarpTile, KInnerLoopIter>,
|
||||
GemmTraits::MIterPerWarp>
|
||||
MIterPerWarp>
|
||||
a_warp_tiles_;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::BWarpTile, KInnerLoopIter>,
|
||||
GemmTraits::NIterPerWarp>
|
||||
NIterPerWarp>
|
||||
b_warp_tiles_;
|
||||
|
||||
template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
@@ -425,30 +419,28 @@ struct BlockUniversalGemmAsBsCr
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}],
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
|
||||
"MPerBlock, NPerBlock, KPerBlock defined in "
|
||||
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
|
||||
|
||||
static_assert(std::is_same_v<typename GemmTraits::ADataType,
|
||||
typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<typename GemmTraits::BDataType,
|
||||
typename BSmemBlockWindow::DataType>,
|
||||
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp);
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);
|
||||
|
||||
// TODO: refactor warp_window tile type to class member as it should be
|
||||
// compile-time known information.
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() +
|
||||
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, KIdx * KPerInnerLoop},
|
||||
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
|
||||
multi_index<2>{iMWarp * WarpGemm::kM, KIdx * KPerInnerLoop},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
|
||||
|
||||
@@ -461,16 +453,16 @@ struct BlockUniversalGemmAsBsCr
|
||||
"AWarpWindow lengths must be equal to AWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<statically_indexed_array<AWarpWindow, KInnerLoopIter>,
|
||||
GemmTraits::MIterPerWarp>
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() +
|
||||
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, KIdx * KPerInnerLoop},
|
||||
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
|
||||
multi_index<2>{iNWarp * WarpGemm::kN, KIdx * KPerInnerLoop},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
|
||||
|
||||
@@ -483,10 +475,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
"BWarpWindow lengths must be equal to BWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<statically_indexed_array<BWarpWindow, KInnerLoopIter>,
|
||||
GemmTraits::NIterPerWarp>
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
@@ -496,7 +488,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
@@ -508,11 +500,11 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
// TODO check if a_warp_tiles has same desc as a_warp_window
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
|
||||
});
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
|
||||
});
|
||||
@@ -525,13 +517,12 @@ struct BlockUniversalGemmAsBsCr
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
|
||||
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
@@ -555,8 +546,8 @@ struct BlockUniversalGemmAsBsCr
|
||||
}
|
||||
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
@@ -573,17 +564,17 @@ struct BlockUniversalGemmAsBsCr
|
||||
// penalty
|
||||
if constexpr(kIter.value == KRepeat - 1 &&
|
||||
kInnerIter.value == KInnerLoopIter - 1 &&
|
||||
mIter.value == GemmTraits::MIterPerWarp - 1 &&
|
||||
nIter.value == GemmTraits::NIterPerWarp - 1)
|
||||
mIter.value == MIterPerWarp - 1 &&
|
||||
nIter.value == NIterPerWarp - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
// warp GEMM
|
||||
typename GemmTraits::WarpGemm{}(c_warp_tensor,
|
||||
a_warp_tiles_[mIter][kInnerIter],
|
||||
b_warp_tiles_[nIter][kInnerIter]);
|
||||
WarpGemm{}(c_warp_tensor,
|
||||
a_warp_tiles_[mIter][kInnerIter],
|
||||
b_warp_tiles_[nIter][kInnerIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
|
||||
Reference in New Issue
Block a user