mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK TILE] Block universal gemm lds<->vgpr optimizations (#1906)
* [CK TILE] Block universal gemm lds<->vgpr optimizations * Rebase * Fixes
This commit is contained in:
@@ -68,16 +68,6 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
|
||||
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
|
||||
|
||||
using AWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
|
||||
typename WarpGemm::AWarpDstrEncoding{}))>;
|
||||
using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
|
||||
typename WarpGemm::BWarpDstrEncoding{}))>;
|
||||
|
||||
using AWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
|
||||
AWarpTileDistr{}))>;
|
||||
using BWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
|
||||
BWarpTileDistr{}))>;
|
||||
|
||||
// TODO: Should we have two policies? Interwave & Intrawave ??
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
|
||||
@@ -108,6 +98,25 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
@@ -116,18 +125,65 @@ struct BlockUniversalGemmAsBsCr
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack);
|
||||
constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KIterInterWave>,
|
||||
sequence<KIterPerWarp>>;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, KIterSeq>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack);
|
||||
constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KIterInterWave>,
|
||||
sequence<KIterPerWarp>>;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, KIterSeq>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(const WarpWindow& warp_window,
|
||||
WarpTile& warp_tile)
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
|
||||
const WarpWindow& warp_window)
|
||||
{
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
constexpr index_t thread_buffer_size =
|
||||
Traits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
static_assert(Traits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
@@ -144,6 +200,17 @@ struct BlockUniversalGemmAsBsCr
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Default, GemmTraits>
|
||||
{
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
@@ -158,114 +225,39 @@ struct BlockUniversalGemmAsBsCr
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
static_assert(
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
|
||||
"MPerBlock, NPerBlock, KPerBlock defined in "
|
||||
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);
|
||||
|
||||
// TODO: refactor warp_window tile type to class member as it should be
|
||||
// compile-time known information.
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
|
||||
AWarpWindow::get_num_of_dimension(),
|
||||
"AWarpWindow number of dimensions must be equal to "
|
||||
"AWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::AWarpTile::get_lengths() ==
|
||||
AWarpWindow{}.get_window_lengths(),
|
||||
"AWarpWindow lengths must be equal to AWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
|
||||
BWarpWindow::get_num_of_dimension(),
|
||||
"BWarpWindow number of dimensions must be equal to "
|
||||
"BWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::BWarpTile::get_lengths() ==
|
||||
BWarpWindow{}.get_window_lengths(),
|
||||
"BWarpWindow lengths must be equal to BWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
// TODO: I don't have to move 0,0 window!
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * GemmTraits::MPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * GemmTraits::NPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
}
|
||||
// hot loop:
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
AWarpTensor a_warp_tile;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter), a_warp_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
BWarpTensor b_warp_tile;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter), b_warp_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -275,7 +267,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile);
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
@@ -291,149 +283,68 @@ struct BlockUniversalGemmAsBsCr
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
|
||||
{
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::AWarpTile, KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_tiles_;
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::BWarpTile, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tiles_;
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
|
||||
"MPerBlock, NPerBlock, KPerBlock defined in "
|
||||
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
|
||||
|
||||
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);
|
||||
|
||||
// TODO: refactor warp_window tile type to class member as it should be
|
||||
// compile-time known information.
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
|
||||
AWarpWindow::get_num_of_dimension(),
|
||||
"AWarpWindow number of dimensions must be equal to "
|
||||
"AWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::AWarpTile::get_lengths() ==
|
||||
AWarpWindow{}.get_window_lengths(),
|
||||
"AWarpWindow lengths must be equal to AWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<statically_indexed_array<AWarpWindow, KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
|
||||
BWarpWindow::get_num_of_dimension(),
|
||||
"BWarpWindow number of dimensions must be equal to "
|
||||
"BWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::BWarpTile::get_lengths() ==
|
||||
BWarpWindow{}.get_window_lengths(),
|
||||
"BWarpWindow lengths must be equal to BWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<statically_indexed_array<BWarpWindow, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
// TODO: I don't have to move 0,0 window!
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * GemmTraits::MPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * GemmTraits::NPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
|
||||
a_warp_tiles_(mIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
|
||||
b_warp_tiles_(nIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
});
|
||||
});
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
[[maybe_unused]] const ASmemBlockWindow& a_block_window,
|
||||
[[maybe_unused]] const BSmemBlockWindow& b_block_window)
|
||||
[[maybe_unused]] ASmemBlockWindow& a_block_window,
|
||||
[[maybe_unused]] BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor-
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
@@ -441,9 +352,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor,
|
||||
a_warp_tiles_[mIter][kIter],
|
||||
b_warp_tiles_[nIter][kIter]);
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
@@ -468,126 +377,53 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::AWarpTile, KInnerLoopIter>,
|
||||
MIterPerWarp>
|
||||
a_warp_tiles_;
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::BWarpTile, KInnerLoopIter>,
|
||||
NIterPerWarp>
|
||||
b_warp_tiles_;
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
|
||||
template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
|
||||
"MPerBlock, NPerBlock, KPerBlock defined in "
|
||||
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
|
||||
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);
|
||||
|
||||
// TODO: refactor warp_window tile type to class member as it should be
|
||||
// compile-time known information.
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() +
|
||||
multi_index<2>{iMWarp * WarpGemm::kM, KIdx * KPerInnerLoop},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
|
||||
AWarpWindow::get_num_of_dimension(),
|
||||
"AWarpWindow number of dimensions must be equal to "
|
||||
"AWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::AWarpTile::get_lengths() ==
|
||||
AWarpWindow{}.get_window_lengths(),
|
||||
"AWarpWindow lengths must be equal to AWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<statically_indexed_array<AWarpWindow, KInnerLoopIter>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{}),
|
||||
{0, KIdx * KPerInnerLoop},
|
||||
a_lds_load_tile_distr);
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() +
|
||||
multi_index<2>{iNWarp * WarpGemm::kN, KIdx * KPerInnerLoop},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{}),
|
||||
{0, KIdx * KPerInnerLoop},
|
||||
b_lds_load_tile_distr);
|
||||
|
||||
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
|
||||
BWarpWindow::get_num_of_dimension(),
|
||||
"BWarpWindow number of dimensions must be equal to "
|
||||
"BWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::BWarpTile::get_lengths() ==
|
||||
BWarpWindow{}.get_window_lengths(),
|
||||
"BWarpWindow lengths must be equal to BWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<statically_indexed_array<BWarpWindow, KInnerLoopIter>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * GemmTraits::MPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * GemmTraits::NPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// TODO check if a_warp_tiles has same desc as a_warp_window
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
|
||||
a_warp_tiles_(mIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
|
||||
b_warp_tiles_(nIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
});
|
||||
});
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_lds_gemm_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_lds_gemm_window);
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
@@ -600,13 +436,6 @@ struct BlockUniversalGemmAsBsCr
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KRepeat, 1>{}([&](auto kIter) {
|
||||
LocalPrefetch<kIter.value>(a_block_window, b_block_window);
|
||||
@@ -626,7 +455,21 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kInnerIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() =
|
||||
b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kInnerIter>{},
|
||||
b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
@@ -651,9 +494,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor,
|
||||
a_warp_tiles_[mIter][kInnerIter],
|
||||
b_warp_tiles_[nIter][kInnerIter]);
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
|
||||
Reference in New Issue
Block a user