mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
CK-Tile first draft of universal block gemm with interwave & intrawave scheduler (#1676)
* Block universal gemm. * Universal block gemm with interwave scheduler - draft. * Refactoring * Move a/b_warp_tiles into BlockGemmImpl * set BlockGemmImpl as a class member * Change tile size for more suitable to memory bound cases. * Introduce kKPerThread to WarpGemm * Add documentation comment. * Fix Interwave scheduler block gemm. * Add compute/memory friendly tile configuration. * Clean * New tile configurations in gemm mem example. * Add more static checks and fix loop order in block gemm. * Add more static checks and use warp gemm mfma dispatcher. * Add default scheduler block gemm. * Remove logging in example.
This commit is contained in:
661
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
Normal file
661
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
Normal file
@@ -0,0 +1,661 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Block-scope universal GEMM: C += A * B.
///
/// A is a block window on shared memory (LDS), B is a block window on
/// shared memory, C is a block distributed tensor (register resident).
/// The hot-loop implementation is selected at compile time from
/// Problem_::Scheduler (Default / Intrawave / Interwave) through the
/// BlockGemmImpl partial specializations defined below.
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
struct BlockUniversalGemmAsBsCr
{
    private:
|
||||
// TODO: This should be in Policy - UniversalGemmPolicyBase ?
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
struct GemmTraits_
|
||||
{
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
static_assert(MWarp == BlockGemmShape::BlockWarps::at(number<0>{}),
|
||||
"Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
|
||||
static_assert(NWarp == BlockGemmShape::BlockWarps::at(number<1>{}),
|
||||
"Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(number<0>{}),
|
||||
"Error! WarpGemm's M is not consisten with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
"Error! WarpGemm's N is not consisten with BlockGemmShape!");
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
|
||||
"Error! Warps should cover all Block tile!");
|
||||
static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
|
||||
"Error! Warps should cover all Block tile!");
|
||||
|
||||
static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM;
|
||||
static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
|
||||
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
|
||||
|
||||
using AWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
|
||||
typename WarpGemm::AWarpDstrEncoding{}))>;
|
||||
using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
|
||||
typename WarpGemm::BWarpDstrEncoding{}))>;
|
||||
|
||||
using AWarpTile =
|
||||
remove_cvref_t<decltype(make_static_distributed_tensor<ADataType>(AWarpTileDistr{}))>;
|
||||
using BWarpTile =
|
||||
remove_cvref_t<decltype(make_static_distributed_tensor<BDataType>(BWarpTileDistr{}))>;
|
||||
|
||||
// TODO: Should we have two policies? Interwave & Intrawave ??
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
static constexpr index_t KPerThread = KPerBlock / WarpGemm::kK * KPack;
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
};
|
||||
|
||||
    public:
    // Resolved traits for this problem/policy combination.
    using Traits = GemmTraits_<Problem_, Policy_>;

    using ADataType = remove_cvref_t<typename Traits::ADataType>;
    using BDataType = remove_cvref_t<typename Traits::BDataType>;
    using CDataType = remove_cvref_t<typename Traits::CDataType>;

    using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;

    // Per-warp iteration counts, re-exported for callers (e.g. pipelines).
    static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
    static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
    static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;

    // Warp grid layout within the block.
    static constexpr index_t MWarp = Traits::MWarp;
    static constexpr index_t NWarp = Traits::NWarp;

    // Scheduler selecting which BlockGemmImpl specialization is used.
    static constexpr auto Scheduler = Traits::Scheduler;

    private:
|
||||
    /// Primary template; intentionally empty. One of the scheduler-specific
    /// partial specializations below (Default / Intrawave / Interwave) is
    /// always selected for the schedulers currently defined.
    template <GemmPipelineScheduler Scheduler, typename GemmTraits>
    struct BlockGemmImpl
    {
    };
|
||||
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Default, GemmTraits>
|
||||
{
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
static_assert(std::is_same_v<typename GemmTraits::ADataType,
|
||||
typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<typename GemmTraits::BDataType,
|
||||
typename BSmemBlockWindow::DataType>,
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
static_assert(
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}],
|
||||
"MPerBlock, NPerBlock, KPerBlock defined in "
|
||||
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp);
|
||||
|
||||
// TODO: refactor warp_window tile type to class member as it should be
|
||||
// compile-time known information.
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() +
|
||||
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0},
|
||||
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
|
||||
AWarpWindow::get_num_of_dimension(),
|
||||
"AWarpWindow number of dimensions must be equal to "
|
||||
"AWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::AWarpTile::get_lengths() ==
|
||||
AWarpWindow{}.get_window_lengths(),
|
||||
"AWarpWindow lengths must be equal to AWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
|
||||
GemmTraits::MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() +
|
||||
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0},
|
||||
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
|
||||
BWarpWindow::get_num_of_dimension(),
|
||||
"BWarpWindow number of dimensions must be equal to "
|
||||
"BWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::BWarpTile::get_lengths() ==
|
||||
BWarpWindow{}.get_window_lengths(),
|
||||
"BWarpWindow lengths must be equal to BWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
|
||||
GemmTraits::NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
// TODO: I don't have to move 0,0 window!
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * GemmTraits::MPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * GemmTraits::NPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
typename GemmTraits::WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
|
||||
{
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::AWarpTile, GemmTraits::KIterPerWarp>,
|
||||
GemmTraits::MIterPerWarp>
|
||||
a_warp_tiles_;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::BWarpTile, GemmTraits::KIterPerWarp>,
|
||||
GemmTraits::NIterPerWarp>
|
||||
b_warp_tiles_;
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}],
|
||||
"MPerBlock, NPerBlock, KPerBlock defined in "
|
||||
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
|
||||
|
||||
static_assert(std::is_same_v<typename GemmTraits::ADataType,
|
||||
typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<typename GemmTraits::BDataType,
|
||||
typename BSmemBlockWindow::DataType>,
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp);
|
||||
|
||||
// TODO: refactor warp_window tile type to class member as it should be
|
||||
// compile-time known information.
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() +
|
||||
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0},
|
||||
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
|
||||
AWarpWindow::get_num_of_dimension(),
|
||||
"AWarpWindow number of dimensions must be equal to "
|
||||
"AWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::AWarpTile::get_lengths() ==
|
||||
AWarpWindow{}.get_window_lengths(),
|
||||
"AWarpWindow lengths must be equal to AWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
|
||||
GemmTraits::MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() +
|
||||
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0},
|
||||
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
|
||||
BWarpWindow::get_num_of_dimension(),
|
||||
"BWarpWindow number of dimensions must be equal to "
|
||||
"BWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::BWarpTile::get_lengths() ==
|
||||
BWarpWindow{}.get_window_lengths(),
|
||||
"BWarpWindow lengths must be equal to BWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
|
||||
GemmTraits::NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
// TODO: I don't have to move 0,0 window!
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * GemmTraits::MPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * GemmTraits::NPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
|
||||
});
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
[[maybe_unused]] const ASmemBlockWindow& a_block_window,
|
||||
[[maybe_unused]] const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
|
||||
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
typename GemmTraits::WarpGemm{}(c_warp_tensor,
|
||||
a_warp_tiles_[mIter][kIter],
|
||||
b_warp_tiles_[nIter][kIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Interwave, GemmTraits>
|
||||
{
|
||||
static constexpr index_t KPerThread = GemmTraits::KPerThread;
|
||||
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
|
||||
static constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, GemmTraits::KPack);
|
||||
// TODO: do we really need this?? Are there any cases when this would be >=1 ??
|
||||
// Would we need InterWaveSchedulingMacClusters > 1 ???
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::AWarpTile, KInnerLoopIter>,
|
||||
GemmTraits::MIterPerWarp>
|
||||
a_warp_tiles_;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::BWarpTile, KInnerLoopIter>,
|
||||
GemmTraits::NIterPerWarp>
|
||||
b_warp_tiles_;
|
||||
|
||||
template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}],
|
||||
"MPerBlock, NPerBlock, KPerBlock defined in "
|
||||
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
|
||||
|
||||
static_assert(std::is_same_v<typename GemmTraits::ADataType,
|
||||
typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<typename GemmTraits::BDataType,
|
||||
typename BSmemBlockWindow::DataType>,
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp);
|
||||
|
||||
// TODO: refactor warp_window tile type to class member as it should be
|
||||
// compile-time known information.
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() +
|
||||
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, KIdx * KPerInnerLoop},
|
||||
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
|
||||
AWarpWindow::get_num_of_dimension(),
|
||||
"AWarpWindow number of dimensions must be equal to "
|
||||
"AWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::AWarpTile::get_lengths() ==
|
||||
AWarpWindow{}.get_window_lengths(),
|
||||
"AWarpWindow lengths must be equal to AWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<statically_indexed_array<AWarpWindow, KInnerLoopIter>,
|
||||
GemmTraits::MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() +
|
||||
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, KIdx * KPerInnerLoop},
|
||||
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
|
||||
BWarpWindow::get_num_of_dimension(),
|
||||
"BWarpWindow number of dimensions must be equal to "
|
||||
"BWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::BWarpTile::get_lengths() ==
|
||||
BWarpWindow{}.get_window_lengths(),
|
||||
"BWarpWindow lengths must be equal to BWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<statically_indexed_array<BWarpWindow, KInnerLoopIter>,
|
||||
GemmTraits::NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * GemmTraits::MPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * GemmTraits::NPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// TODO check if a_warp_tiles has same desc as a_warp_window
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
|
||||
});
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
|
||||
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KRepeat, 1>{}([&](auto kIter) {
|
||||
LocalPrefetch<kIter.value>(a_block_window, b_block_window);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// NOTE: Synchronize threads in a workgroup at the start of each MAC
|
||||
// cluster, but except the first, as we can shorten non-MAC cluster a bit
|
||||
// and there's no observable negative impact. The desired effect is waves in
|
||||
// a workgroup executing MAC in sync. This avoids some out-of-sync waves
|
||||
// hijacking MAC resource from other workgroups and reducing the chance of
|
||||
// latency hiding by waiting for the rest of the workgroup at the eventual
|
||||
// sync point.
|
||||
if constexpr(kIter.value != 0 || KRepeat == 1)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
|
||||
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// The block_sync_lds() here performs double duty:
|
||||
// A) safeguard against data hazard because barrier from
|
||||
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion
|
||||
// by applying small delays to different wavefronts It is
|
||||
// performed near the end of MAC cluster to minimize lgkmcnt
|
||||
// penalty
|
||||
if constexpr(kIter.value == KRepeat - 1 &&
|
||||
kInnerIter.value == KInnerLoopIter - 1 &&
|
||||
mIter.value == GemmTraits::MIterPerWarp - 1 &&
|
||||
nIter.value == GemmTraits::NIterPerWarp - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
// warp GEMM
|
||||
typename GemmTraits::WarpGemm{}(c_warp_tensor,
|
||||
a_warp_tiles_[mIter][kInnerIter],
|
||||
b_warp_tiles_[nIter][kInnerIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
|
||||
if constexpr(kInnerIter.value == 0 && mIter.value == 0 &&
|
||||
nIter.value == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
block_gemm_impl_.template LocalPrefetch(a_block_window, b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
block_gemm_impl_.template operator()(c_block_tensor, a_block_window, b_block_window);
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
block_gemm_impl_.template operator()(c_block_tensor, a_block_window, b_block_window);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
    private:
    // Scheduler-specific implementation, selected at compile time from
    // Problem_::Scheduler via the BlockGemmImpl partial specializations.
    BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -247,8 +247,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
@@ -290,7 +290,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
block_sync_lds();
|
||||
// block_gemm.LocalPrefetch();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
@@ -318,7 +318,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
|
||||
block_sync_lds();
|
||||
|
||||
// block_gemm.LocalPrefetch();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
@@ -331,14 +331,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
// block_gemm.LocalPrefetch();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
};
|
||||
|
||||
if constexpr(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
// block_gemm.LocalPrefetch();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Two)
|
||||
|
||||
@@ -11,6 +11,7 @@ namespace ck_tile {
|
||||
|
||||
enum struct GemmPipelineScheduler
|
||||
{
|
||||
Default,
|
||||
Intrawave,
|
||||
Interwave,
|
||||
};
|
||||
@@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch
|
||||
{
|
||||
switch(s)
|
||||
{
|
||||
case ck_tile::GemmPipelineScheduler::Default: os << "Default"; break;
|
||||
case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break;
|
||||
case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break;
|
||||
default: os << "";
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -52,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// TODO: this 8 is AK1! should be a policy parameter!
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
|
||||
@@ -264,6 +266,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M2, M1 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
@@ -277,6 +282,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
constexpr index_t M0 = BlockSize / get_warp_size();
|
||||
constexpr index_t M1 = MPerBlock / (M2 * M0);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M1, M2 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
@@ -350,6 +358,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
static_assert(N0 * N1 * N2 == NPerBlock,
|
||||
"Incorrect N0, N1, N2 configuration! "
|
||||
"N0, N1, N2 must cover whole NPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
@@ -364,7 +375,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
constexpr index_t N0 = BlockSize / get_warp_size();
|
||||
constexpr index_t N1 = NPerBlock / (N2 * N0);
|
||||
|
||||
static_assert(N0 * N1 * N2 == NPerBlock,
|
||||
"Incorrect N0, N1, N2 configuration! "
|
||||
"N0, N1, N2 must cover whole NPerBlock!");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
@@ -475,9 +488,28 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy;
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
|
||||
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -33,6 +33,8 @@ struct GemmPipelineProblemBase
|
||||
static constexpr bool kPadN = GemmTraits::kPadN;
|
||||
static constexpr bool kPadK = GemmTraits::kPadK;
|
||||
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -21,9 +21,10 @@ struct WarpGemmAtrributeMfma
|
||||
using BVecType = typename Impl::BVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
@@ -86,9 +87,10 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
@@ -197,9 +199,10 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
|
||||
using BVecType = typename Impl::AVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
@@ -260,9 +263,10 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
|
||||
using BVecType = typename Impl::AVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
@@ -330,9 +334,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
@@ -444,10 +449,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
@@ -583,10 +589,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,6 +14,11 @@ struct WarpGemmImpl
|
||||
static constexpr index_t kM = WarpGemmAttribute::kM;
|
||||
static constexpr index_t kN = WarpGemmAttribute::kN;
|
||||
static constexpr index_t kK = WarpGemmAttribute::kK;
|
||||
/// @brief The number of elements in K dimension processed by single thread in wavefront.
|
||||
///
|
||||
/// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).
|
||||
/// In such situation this value reflects this fact.
|
||||
static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
|
||||
|
||||
using ADataType = typename WarpGemmAttribute::ADataType;
|
||||
using BDataType = typename WarpGemmAttribute::BDataType;
|
||||
|
||||
Reference in New Issue
Block a user