CK-Tile first draft of universal block gemm with interwave & intrawave scheduler (#1676)

* Block universal gemm.

* Universal block gemm with interwave scheduler - draft.

* Refactoring

* Move a/b_warp_tiles into BlockGemmImpl
* set BlockGemmImpl as a class member

* Change tile size for more suitable to memory bound cases.

* Introduce kKPerThread to WarpGemm

* Add documentation comment.

* Fix Interwave scheduler block gemm.

* Add compute/memory friendly tile configuration.

* Clean

* New tile configurations in gemm mem example.

* Add more static checks and fix loop order in block gemm.

* Add more static checks and use warp gemm mfma dispatcher.

* Add default scheduler block gemm.

* Remove logging in example.
This commit is contained in:
Adam Osewski
2024-11-26 08:45:14 +01:00
committed by GitHub
parent 440e28b08f
commit b6bcd76d88
11 changed files with 779 additions and 56 deletions

View File

@@ -0,0 +1,661 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
struct BlockUniversalGemmAsBsCr
{
private:
// TODO: This should be in Policy - UniversalGemmPolicyBase ?
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == BlockGemmShape::BlockWarps::at(number<0>{}),
"Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
static_assert(NWarp == BlockGemmShape::BlockWarps::at(number<1>{}),
"Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(number<0>{}),
"Error! WarpGemm's M is not consisten with BlockGemmShape!");
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(number<1>{}),
"Error! WarpGemm's N is not consisten with BlockGemmShape!");
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
"Error! Warps should cover all Block tile!");
static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
"Error! Warps should cover all Block tile!");
static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM;
static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
using AWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
typename WarpGemm::AWarpDstrEncoding{}))>;
using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
typename WarpGemm::BWarpDstrEncoding{}))>;
using AWarpTile =
remove_cvref_t<decltype(make_static_distributed_tensor<ADataType>(AWarpTileDistr{}))>;
using BWarpTile =
remove_cvref_t<decltype(make_static_distributed_tensor<BDataType>(BWarpTileDistr{}))>;
// TODO: Should we have two policies? Interwave & Intrawave ??
static constexpr index_t InterWaveSchedulingMacClusters = 1;
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KPerBlock / WarpGemm::kK * KPack;
static constexpr index_t KRepeat = KPerThread / KPack;
};
public:
using Traits = GemmTraits_<Problem_, Policy_>;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
static constexpr auto Scheduler = Traits::Scheduler;
private:
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl
{
};
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Default, GemmTraits>
{
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
static_assert(
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
static_assert(std::is_same_v<typename GemmTraits::ADataType,
typename ASmemBlockWindow::DataType> &&
std::is_same_v<typename GemmTraits::BDataType,
typename BSmemBlockWindow::DataType>,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");
static_assert(
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}],
"MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp;
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp);
// TODO: refactor warp_window tile type to class member as it should be
// compile-time known information.
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}),
a_block_window.get_window_origin() +
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0},
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
AWarpWindow::get_num_of_dimension(),
"AWarpWindow number of dimensions must be equal to "
"AWarpTile number of dimensions!");
static_assert(GemmTraits::AWarpTile::get_lengths() ==
AWarpWindow{}.get_window_lengths(),
"AWarpWindow lengths must be equal to AWarpTile lengths!");
statically_indexed_array<
statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
GemmTraits::MIterPerWarp>
a_warp_windows;
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}),
b_block_window.get_window_origin() +
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0},
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
BWarpWindow::get_num_of_dimension(),
"BWarpWindow number of dimensions must be equal to "
"BWarpTile number of dimensions!");
static_assert(GemmTraits::BWarpTile::get_lengths() ==
BWarpWindow{}.get_window_lengths(),
"BWarpWindow lengths must be equal to BWarpTile lengths!");
statically_indexed_array<
statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
GemmTraits::NIterPerWarp>
b_warp_windows;
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// TODO: I don't have to move 0,0 window!
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * GemmTraits::MPerBlockPerIter,
kIter * GemmTraits::KPerBlockPerIter});
});
});
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * GemmTraits::NPerBlockPerIter,
kIter * GemmTraits::KPerBlockPerIter});
});
});
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr;
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor-
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
typename GemmTraits::WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
};
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
{
statically_indexed_array<
statically_indexed_array<typename GemmTraits::AWarpTile, GemmTraits::KIterPerWarp>,
GemmTraits::MIterPerWarp>
a_warp_tiles_;
statically_indexed_array<
statically_indexed_array<typename GemmTraits::BWarpTile, GemmTraits::KIterPerWarp>,
GemmTraits::NIterPerWarp>
b_warp_tiles_;
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
static_assert(
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}],
"MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
static_assert(std::is_same_v<typename GemmTraits::ADataType,
typename ASmemBlockWindow::DataType> &&
std::is_same_v<typename GemmTraits::BDataType,
typename BSmemBlockWindow::DataType>,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp;
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp);
// TODO: refactor warp_window tile type to class member as it should be
// compile-time known information.
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}),
a_block_window.get_window_origin() +
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0},
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
AWarpWindow::get_num_of_dimension(),
"AWarpWindow number of dimensions must be equal to "
"AWarpTile number of dimensions!");
static_assert(GemmTraits::AWarpTile::get_lengths() ==
AWarpWindow{}.get_window_lengths(),
"AWarpWindow lengths must be equal to AWarpTile lengths!");
statically_indexed_array<
statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
GemmTraits::MIterPerWarp>
a_warp_windows;
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}),
b_block_window.get_window_origin() +
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0},
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
BWarpWindow::get_num_of_dimension(),
"BWarpWindow number of dimensions must be equal to "
"BWarpTile number of dimensions!");
static_assert(GemmTraits::BWarpTile::get_lengths() ==
BWarpWindow{}.get_window_lengths(),
"BWarpWindow lengths must be equal to BWarpTile lengths!");
statically_indexed_array<
statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
GemmTraits::NIterPerWarp>
b_warp_windows;
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// TODO: I don't have to move 0,0 window!
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * GemmTraits::MPerBlockPerIter,
kIter * GemmTraits::KPerBlockPerIter});
});
});
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * GemmTraits::NPerBlockPerIter,
kIter * GemmTraits::KPerBlockPerIter});
});
});
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
});
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
});
});
}
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
[[maybe_unused]] const ASmemBlockWindow& a_block_window,
[[maybe_unused]] const BSmemBlockWindow& b_block_window)
{
static_assert(
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr;
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor-
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
typename GemmTraits::WarpGemm{}(c_warp_tensor,
a_warp_tiles_[mIter][kIter],
b_warp_tiles_[nIter][kIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
};
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Interwave, GemmTraits>
{
static constexpr index_t KPerThread = GemmTraits::KPerThread;
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
static constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, GemmTraits::KPack);
// TODO: do we really need this?? Are there any cases when this would be >=1 ??
// Would we need InterWaveSchedulingMacClusters > 1 ???
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack;
statically_indexed_array<
statically_indexed_array<typename GemmTraits::AWarpTile, KInnerLoopIter>,
GemmTraits::MIterPerWarp>
a_warp_tiles_;
statically_indexed_array<
statically_indexed_array<typename GemmTraits::BWarpTile, KInnerLoopIter>,
GemmTraits::NIterPerWarp>
b_warp_tiles_;
template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
static_assert(
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}],
"MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
static_assert(std::is_same_v<typename GemmTraits::ADataType,
typename ASmemBlockWindow::DataType> &&
std::is_same_v<typename GemmTraits::BDataType,
typename BSmemBlockWindow::DataType>,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp;
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp);
// TODO: refactor warp_window tile type to class member as it should be
// compile-time known information.
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}),
a_block_window.get_window_origin() +
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, KIdx * KPerInnerLoop},
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
AWarpWindow::get_num_of_dimension(),
"AWarpWindow number of dimensions must be equal to "
"AWarpTile number of dimensions!");
static_assert(GemmTraits::AWarpTile::get_lengths() ==
AWarpWindow{}.get_window_lengths(),
"AWarpWindow lengths must be equal to AWarpTile lengths!");
statically_indexed_array<statically_indexed_array<AWarpWindow, KInnerLoopIter>,
GemmTraits::MIterPerWarp>
a_warp_windows;
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}),
b_block_window.get_window_origin() +
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, KIdx * KPerInnerLoop},
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
BWarpWindow::get_num_of_dimension(),
"BWarpWindow number of dimensions must be equal to "
"BWarpTile number of dimensions!");
static_assert(GemmTraits::BWarpTile::get_lengths() ==
BWarpWindow{}.get_window_lengths(),
"BWarpWindow lengths must be equal to BWarpTile lengths!");
statically_indexed_array<statically_indexed_array<BWarpWindow, KInnerLoopIter>,
GemmTraits::NIterPerWarp>
b_warp_windows;
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * GemmTraits::MPerBlockPerIter,
kIter * GemmTraits::KPerBlockPerIter});
});
});
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * GemmTraits::NPerBlockPerIter,
kIter * GemmTraits::KPerBlockPerIter});
});
});
// TODO check if a_warp_tiles has same desc as a_warp_window
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
});
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
});
});
}
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
static_assert(
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr;
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KRepeat, 1>{}([&](auto kIter) {
LocalPrefetch<kIter.value>(a_block_window, b_block_window);
__builtin_amdgcn_sched_barrier(0);
// NOTE: Synchronize threads in a workgroup at the start of each MAC
// cluster, but except the first, as we can shorten non-MAC cluster a bit
// and there's no observable negative impact. The desired effect is waves in
// a workgroup executing MAC in sync. This avoids some out-of-sync waves
// hijacking MAC resource from other workgroups and reducing the chance of
// latency hiding by waiting for the rest of the workgroup at the eventual
// sync point.
if constexpr(kIter.value != 0 || KRepeat == 1)
{
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor-
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts It is
// performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(kIter.value == KRepeat - 1 &&
kInnerIter.value == KInnerLoopIter - 1 &&
mIter.value == GemmTraits::MIterPerWarp - 1 &&
nIter.value == GemmTraits::NIterPerWarp - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
// warp GEMM
typename GemmTraits::WarpGemm{}(c_warp_tensor,
a_warp_tiles_[mIter][kInnerIter],
b_warp_tiles_[nIter][kInnerIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
if constexpr(kInnerIter.value == 0 && mIter.value == 0 &&
nIter.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
});
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
});
}
};
public:
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_.template LocalPrefetch(a_block_window, b_block_window);
}
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_.template operator()(c_block_tensor, a_block_window, b_block_window);
}
// C = A * B
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
auto c_block_tensor = MakeCBlockTile();
block_gemm_impl_.template operator()(c_block_tensor, a_block_window, b_block_window);
return c_block_tensor;
}
private:
BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
};
} // namespace ck_tile

View File

@@ -247,8 +247,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
// Block GEMM
constexpr auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
@@ -290,7 +290,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
@@ -318,7 +318,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
@@ -331,14 +331,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
});
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
};
if constexpr(TailNum == TailNumber::One)
{
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
else if constexpr(TailNum == TailNumber::Two)

View File

@@ -11,6 +11,7 @@ namespace ck_tile {
enum struct GemmPipelineScheduler
{
Default,
Intrawave,
Interwave,
};
@@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch
{
switch(s)
{
case ck_tile::GemmPipelineScheduler::Default: os << "Default"; break;
case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break;
case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break;
default: os << "";

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
@@ -52,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// TODO: this 8 is AK1! should be a policy parameter!
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
@@ -264,6 +266,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
@@ -277,6 +282,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
@@ -350,6 +358,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
@@ -364,7 +375,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
@@ -475,9 +488,28 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy;
constexpr bool TransposeC = false;
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto I2 = number<2>{};
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
AccDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
TransposeC>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
}
};

View File

@@ -33,6 +33,8 @@ struct GemmPipelineProblemBase
static constexpr bool kPadN = GemmTraits::kPadN;
static constexpr bool kPadK = GemmTraits::kPadK;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -21,9 +21,10 @@ struct WarpGemmAtrributeMfma
using BVecType = typename Impl::BVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
@@ -86,9 +87,10 @@ struct WarpGemmAtrributeMfmaIterateK
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
@@ -197,9 +199,10 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
using BVecType = typename Impl::AVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
@@ -260,9 +263,10 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
using BVecType = typename Impl::AVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
@@ -330,9 +334,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
@@ -444,10 +449,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
@@ -583,10 +589,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -14,6 +14,11 @@ struct WarpGemmImpl
static constexpr index_t kM = WarpGemmAttribute::kM;
static constexpr index_t kN = WarpGemmAttribute::kN;
static constexpr index_t kK = WarpGemmAttribute::kK;
/// @brief The number of elements in K dimension processed by single thread in wavefront.
///
/// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).
/// In such situation this value reflects this fact.
static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
using ADataType = typename WarpGemmAttribute::ADataType;
using BDataType = typename WarpGemmAttribute::BDataType;