Adapt gemm_mx_kernel.hpp from flatmm; comment the changes needed for the MX pipeline relative to flatmm

This commit is contained in:
Sami Remes
2025-12-18 04:06:04 -05:00
parent 292df2719f
commit 4985afb03c
6 changed files with 2566 additions and 0 deletions


@@ -0,0 +1,599 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
namespace ck_tile {
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_,
typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
index_t UnaryOpSize_ = 8>
struct BlockUniversalGemmAsBsCr
{
private:
// TODO: This should be in Policy - UniversalGemmPolicyBase ?
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
using I0 = number<0>;
using I1 = number<1>;
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
"Error! WarpGemm's MWarp is not consistent with BlockGemmShape!");
static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
"Error! WarpGemm's NWarp is not consistent with BlockGemmShape!");
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
"Error! WarpGemm's M is not consistent with BlockGemmShape!");
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
"Error! WarpGemm's N is not consistent with BlockGemmShape!");
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
"Error! Warps should cover all Block tile!");
static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
"Error! Warps should cover all Block tile!");
static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM;
static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
// Controls how many MAC clusters (MFMA blocks) we have per wave.
// I.e., if
// InterWaveSchedulingMacClusters = 1;
// KPerBlock == 32
// WarpGemm::kK = 8
// then we would group all 4 WarpGemms into a single MAC cluster.
// If we instead set InterWaveSchedulingMacClusters = 2, those 4 warp
// gemms would be split into two groups.
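// Worked example (illustrative numbers): with KPerBlock = 32 and WarpGemm::kK = 8,
// KIterPerWarp = 4. With InterWaveSchedulingMacClusters = 2 the Interwave scheduler
// would use KPerInnerLoop = KPerThread / 2, giving KRepeat = 2 MAC clusters of
// KInnerLoopIter = 2 WarpGemms each per wave.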
static constexpr index_t InterWaveSchedulingMacClusters = 1;
// should be at least equal to: WarpGemm::Impl::kABKPerLane
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
};
public:
using Traits = GemmTraits_<Problem_, Policy_>;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_raw_t>,
ADataType,
BDataType>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
static constexpr auto Scheduler = Traits::Scheduler;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using I0 = number<0>;
using I1 = number<1>;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, KIterSeq>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, KIterSeq>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl
{
};
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Default, GemmTraits>
{
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
// C += A * B
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
};
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
{
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
}
// C += A * B
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow&,
const BSmemBlockWindow&,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
};
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Interwave, GemmTraits>
{
static constexpr index_t KPerThread = GemmTraits::KPerThread;
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
static constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
static constexpr auto ALdsTileDistr =
make_static_tile_distribution(MakeABlockDistributionEncode());
static constexpr auto BLdsTileDistr =
make_static_tile_distribution(MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
template <index_t KIdx,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
constexpr auto a_lds_load_distr = [&]() {
if constexpr(ALoadTranspose)
return make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(MakeABlockDistributionEncode()),
ADataType>::TransposedDstrEncode{});
else
return make_static_tile_distribution(MakeABlockDistributionEncode());
}();
constexpr auto b_lds_load_distr = [&]() {
if constexpr(BLoadTranspose)
return make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(MakeBBlockDistributionEncode()),
BDataType>::TransposedDstrEncode{});
else
return make_static_tile_distribution(MakeBBlockDistributionEncode());
}();
constexpr auto a_lds_shape = []() {
if constexpr(ALoadTranspose)
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::MPerBlock>{});
else
return make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{});
}();
constexpr auto b_lds_shape = []() {
if constexpr(BLoadTranspose)
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::NPerBlock>{});
else
return make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{});
}();
constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
constexpr auto a_offset =
ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
constexpr auto b_offset =
BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
auto a_lds_gemm_window = make_tile_window(
a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_lds_gemm_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_lds_gemm_window);
}
// C += A * B
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
// hot loop:
static_for<0, KRepeat, 1>{}([&](auto kIter) {
LocalPrefetch<kIter.value>(a_block_window, b_block_window, a_load_tr, b_load_tr);
__builtin_amdgcn_sched_barrier(0);
// NOTE: Synchronize threads in a workgroup at the start of each MAC
// cluster, except the first, as we can shorten the non-MAC cluster a bit
// and there's no observable negative impact. The desired effect is waves in
// a workgroup executing MAC in sync. This avoids some out-of-sync waves
// hijacking MAC resource from other workgroups and reducing the chance of
// latency hiding by waiting for the rest of the workgroup at the eventual
// sync point.
if constexpr(kIter.value != 0 || KRepeat == 1)
{
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kInnerIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kInnerIter>{},
b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard, because the barrier from
//    blockwise_gemm is moved here;
// B) reduce VMEM FIFO congestion by applying small delays to
//    different wavefronts.
// It is performed near the end of the MAC cluster to minimize the
// lgkmcnt penalty.
if constexpr(kIter.value == KRepeat - 1 &&
kInnerIter.value == KInnerLoopIter - 1 &&
mIter.value == MIterPerWarp - 1 &&
nIter.value == NIterPerWarp - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
if constexpr(kInnerIter.value == 0 && mIter.value == 0 &&
nIter.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
});
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
});
}
};
public:
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}
// C += A * B
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
}
// C = A * B
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
auto c_block_tensor = MakeCBlockTile();
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
return c_block_tensor;
}
private:
BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
};
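// Usage sketch (illustrative; the problem type and LDS window names are assumptions):
//   using BlockGemm = BlockUniversalGemmAsBsCr<MyGemmPipelineProblem>;
//   BlockGemm block_gemm{};
//   auto c_tile = BlockGemm::MakeCBlockTile();            // accumulator tile (clear before use)
//   block_gemm.LocalPrefetch(a_lds_window, b_lds_window); // Intrawave scheduler only;
//                                                         // other schedulers load inside operator()
//   block_gemm(c_tile, a_lds_window, b_lds_window);       // C += A * B from LDS block windows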
} // namespace ck_tile


@@ -0,0 +1,381 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
namespace ck_tile {
template <typename ScaleM = MXScalePointer<-1>,
typename ScaleN = MXScalePointer<-1>,
index_t NumATensor = 1,
index_t NumBTensor = 1,
index_t NumDTensor = 0>
struct MXGemmKernelArgs : UniversalGemmKernelArgs<NumATensor, NumBTensor, NumDTensor>
{
using Base = UniversalGemmKernelArgs<NumATensor, NumBTensor, NumDTensor>;
CK_TILE_HOST MXGemmKernelArgs(const std::array<const void*, NumATensor>& as_ptr_,
const std::array<const void*, NumBTensor>& bs_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
const std::array<index_t, NumATensor>& stride_As_,
const std::array<index_t, NumBTensor>& stride_Bs_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
: Base(as_ptr_,
bs_ptr_,
ds_ptr_,
e_ptr_,
k_batch_,
M_,
N_,
K_,
stride_As_,
stride_Bs_,
stride_Ds_,
stride_E_)
{
}
// Scale pointers referenced by the kernel as kargs.scale_m_ptr / kargs.scale_n_ptr;
// they are expected to be filled in by the caller after construction.
ScaleM scale_m_ptr;
ScaleN scale_n_ptr;
};
template <typename TilePartitioner_, typename MXGemmPipeline_, typename EpiloguePipeline_>
struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, EpiloguePipeline_>
{
using Underlying = UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, EpiloguePipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using MXGemmPipeline = remove_cvref_t<MXGemmPipeline_>;
using BlockGemmShape =
remove_cvref_t<typename MXGemmPipeline::BlockGemmShape>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename MXGemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename MXGemmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename MXGemmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t KernelBlockSize = MXGemmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = MXGemmPipeline::UsePersistentKernel;
using ADataType = remove_cvref_t<typename MXGemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename MXGemmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
static constexpr auto I4 = number<4>();
static constexpr auto I5 = number<5>();
static constexpr index_t NumATensor = Underlying::AsDataType::size();
static constexpr index_t NumBTensor = Underlying::BsDataType::size();
static constexpr index_t NumDTensor = Underlying::DsDataType::size();
static constexpr auto MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{});
static constexpr auto NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{});
static constexpr auto KThreadPerXdl = 64 / MThreadPerXdl;
static constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr auto MXdlPack = MXGemmPipeline::MXdlPack;
static constexpr auto NXdlPack = MXGemmPipeline::NXdlPack;
static constexpr auto KXdlPack = MXGemmPipeline::KXdlPack;
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "mx_gemm", gemm_prec_str<ADataType, BDataType>, MXGemmPipeline::GetName());
// clang-format on
}
template <typename ScaleM, typename ScaleN>
using KernelArgs = MXGemmKernelArgs<ScaleM, ScaleN, NumATensor, NumBTensor, NumDTensor>;
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr auto
GridSize(const KernelArgs<ScaleM, ScaleN>& kargs)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;
if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
hipGetErrorName(hipGetLastError()));
if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry<1, MXGemmKernel, remove_cvref_t<decltype(kargs)>>),
KernelBlockSize,
dync_smem_size) != hipSuccess)
throw std::runtime_error(
std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
hipGetErrorName(hipGetLastError()));
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, 1);
}
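// Note: with UsePersistentKernel the grid is capped at one block per occupancy slot
// (multiProcessorCount * maxActiveBlocksPerCU); operator() below then strides
// partition_idx by gridDim.x until every work tile has been processed.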
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
template <memory_operation_enum DstInMemOp = memory_operation_enum::set,
typename ScaleM,
typename ScaleN>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const std::array<const ADataType*, NumATensor>& as_ptr,
const std::array<const BDataType*, NumBTensor>& bs_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
const KernelArgs<ScaleM, ScaleN>& kargs,
const index_t k_size)
{
// Get tensor views from the UniversalGemmKernel
const auto& gemm_tensor_views_tuple =
Underlying::template MakeGemmTensorViews<DstInMemOp>(
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, k_size);
auto scale_a = kargs.scale_m_ptr;
auto scale_b = kargs.scale_n_ptr;
static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
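// Worked example (illustrative sizes): with M = 256, K = 512, MThreadPerXdl = 16,
// KThreadPerXdl = 4 and MXdlPack = KXdlPack = 2, scale_packs_m = 8 and
// scale_packs_k = 512 / 32 / 8 = 2, i.e. each pack holds a 2x2 group of e8m0
// scales loaded as one int32_t.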
// A scale tensor view
const auto& scale_a_tensor_view = [&]() {
// Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl));
const auto scale_a_desc = transform_tensor_descriptor(
scale_a_naive_desc,
make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)),
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
}();
// B scale tensor view
const auto& scale_b_tensor_view = [&]() {
const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
const auto scale_b_desc = transform_tensor_descriptor(
scale_b_naive_desc,
make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
}();
return concat_tuple(gemm_tensor_views_tuple, make_tuple(scale_a_tensor_view, scale_b_tensor_view));
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& padded_views = Underlying::template MakeGemmPadViews(views);
return make_tuple(
padded_views.at(I0), padded_views.at(I1), padded_views.at(I2), padded_views.at(I3), views.at(I4), views.at(I5));
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& tile_windows = Underlying::template MakeGemmTileWindows(views, i_m, i_n);
static constexpr int BlockScaleSize = 32;
auto scale_a_block_window = make_tile_window(
views.at(I4),
make_tuple(number<TilePartitioner::MPerBlock / MXdlPack>{},
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
{i_m / MXdlPack, 0});
auto scale_b_block_window = make_tile_window(
views.at(I5),
make_tuple(number<TilePartitioner::NPerBlock / NXdlPack>{},
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
{i_n / NXdlPack, 0});
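// Illustrative sizes: with MPerBlock = 128, KPerBlock = 256, MXdlPack = KXdlPack = 2
// and BlockScaleSize = 32, the A-scale window spans 64 x 4 packed int32 entries
// (each entry carrying a 2x2 group of e8m0 scales), and analogously for B.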
return make_tuple(tile_windows.at(I0),
tile_windows.at(I1),
tile_windows.at(I2),
tile_windows.at(I3),
scale_a_block_window,
scale_b_block_window);
}
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void
RunMxGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
const KernelArgs<ScaleM, ScaleN>& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& scale_a_block_window = gemm_tile_windows.at(I4);
const auto& scale_b_block_window = gemm_tile_windows.at(I5);
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|| ScaleM::GranularityMN == -1 // or ScaleA is disabled
|| ScaleN::GranularityMN == -1, // or ScaleB is disabled
"ScaleM and ScaleN should have the same GranularityK");
constexpr bool DoEpiScale =
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
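// e.g. ScaleM = MXScalePointer<1, 0> (one scale per output row, "per token") or
// ScaleN = MXScalePointer<1, 0> ("per channel") selects the epilogue-scaling path,
// while MXScalePointer<-1> disables the corresponding scale entirely.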
const auto& c_block_tile = MXGemmPipeline{}(a_block_window,
b_flat_block_window,
scale_a_block_window,
scale_b_block_window,
num_loop,
smem_ptr_ping,
smem_ptr_pong);
// Run Epilogue Pipeline
if constexpr(DoEpiScale)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}(c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
{
return max(MXGemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
{
return MXGemmPipeline::GetSmemSize();
}
template <class ScaleM, class ScaleN>
CK_TILE_DEVICE void operator()(KernelArgs<ScaleM, ScaleN> kargs,
int partition_idx = get_block_id()) const
{
const int total_work_tile_cnt = amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N));
do
{
const auto [iM, iN] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs);
// Apply the split-K offsets to each A/B tensor pointer
std::array<const ADataType*, NumATensor> as_ptr;
static_for<0, NumATensor, 1>{}([&](auto i) {
as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
splitk_batch_offset.as_k_split_offset[i] / APackedSize;
});
std::array<const BDataType*, NumBTensor> bs_ptr;
static_for<0, NumBTensor, 1>{}([&](auto i) {
bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
splitk_batch_offset.bs_k_split_offset[i] / BPackedSize;
});
// Calculate output offset from tile partitioner and apply to output pointer
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
if constexpr(has_tile_partitioner_output_offset)
{
const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, blockIdx.z);
e_ptr += output_offset;
}
// allocate LDS
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr bool use_default_scheduler = (MXGemmPipeline::NumWaveGroups == 1);
RunMxGemm<ScaleM, ScaleN, use_default_scheduler>(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
static_assert(false,
"Unimplemented: atomic_add with odd vector size for fp16/bf16");
}
partition_idx += gridDim.x;
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
}
};
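// Host-side launch sketch (illustrative; the concrete launch helper and argument
// values are assumptions, only GridSize/KernelBlockSize/kentry come from this file):
//   using Kernel = MXGemmKernel<TilePartitioner, MXGemmPipeline, EpiloguePipeline>;
//   Kernel::KernelArgs<ScaleM, ScaleN> kargs{as_ptr, bs_ptr, ds_ptr, e_ptr, k_batch,
//                                            M, N, K, stride_As, stride_Bs, stride_Ds, stride_E};
//   const dim3 grids  = Kernel::GridSize(kargs);
//   const dim3 blocks = dim3(Kernel::KernelBlockSize);
//   // launch kentry<1, Kernel, decltype(kargs)> with <<<grids, blocks, 0, stream>>>(kargs)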
} // namespace ck_tile


@@ -0,0 +1,110 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <int SharedGranularityMN, int SharedGranularityK = 0>
struct MXScalePointer
{
static constexpr int GranularityMN = SharedGranularityMN;
static constexpr int GranularityK = SharedGranularityK;
const float* ptr;
CK_TILE_HOST_DEVICE MXScalePointer() = default;
CK_TILE_HOST_DEVICE MXScalePointer(const float* ptr_) : ptr(ptr_) {}
CK_TILE_HOST_DEVICE MXScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
: ptr(ptr_)
{
}
CK_TILE_HOST_DEVICE MXScalePointer operator+(index_t offset) const
{
MXScalePointer ret;
if constexpr(GranularityMN == 0)
{
ret.ptr = ptr + offset / GranularityK;
}
else
{
ret.ptr = ptr + offset / GranularityMN / GranularityK;
}
return ret;
}
CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
};
template <int SharedGranularityMN>
struct MXScalePointer<SharedGranularityMN, 0>
{
static constexpr int GranularityMN = SharedGranularityMN;
static constexpr int GranularityK = 0;
static_assert(GranularityMN != 0);
const float* ptr;
index_t length;
CK_TILE_HOST_DEVICE MXScalePointer() = default;
CK_TILE_HOST_DEVICE MXScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
CK_TILE_HOST_DEVICE MXScalePointer(const float* ptr_, index_t length_)
: ptr(ptr_), length(length_)
{
}
CK_TILE_HOST_DEVICE MXScalePointer operator+(index_t offset) const
{
MXScalePointer ret;
if constexpr(GranularityMN == 1)
{
ret.ptr = ptr + offset;
ret.length = length - offset;
}
else
{
ret.ptr = ptr + offset / GranularityMN;
ret.length = length - offset / GranularityMN;
}
return ret;
}
CK_TILE_HOST_DEVICE float operator[](index_t i) const
{
// with additional oob check
if constexpr(GranularityMN == 1)
return i < length ? ptr[i] : 0;
else
return i / GranularityMN < length ? ptr[i / GranularityMN] : 0;
}
};
// shared granularityMN = -1 means no scale
template <>
struct MXScalePointer<-1, 0>
{
static constexpr int GranularityMN = -1;
static constexpr int GranularityK = 0;
const float* ptr = nullptr;
CK_TILE_HOST_DEVICE constexpr MXScalePointer() = default;
CK_TILE_HOST_DEVICE constexpr MXScalePointer(const float*) {}
CK_TILE_HOST_DEVICE constexpr MXScalePointer(const float*, index_t) {}
CK_TILE_HOST_DEVICE constexpr MXScalePointer operator+(index_t) const
{
return MXScalePointer{};
}
CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
{
return 1; // always return 1, it doesn't change the result
}
};
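// Usage sketch (illustrative): per-token A scaling together with disabled B scaling
// could be expressed as
//   using ScaleM = MXScalePointer<1>;    // GranularityMN = 1, GranularityK = 0
//   using ScaleN = MXScalePointer<-1>;   // no scale; operator[] always returns 1
//   ScaleM scale_m{scale_m_ptr, M};
//   float s = scale_m[row];              // rows beyond `length` read as 0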
} // namespace ck_tile

File diff suppressed because it is too large.


@@ -0,0 +1,379 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename Problem>
struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr index_t kDramLoadPackBytes = 128;
static constexpr index_t DWORDx4 = 16;
static constexpr int MXdlPack = 2;
static constexpr int NXdlPack = 2;
static constexpr int KXdlPack = 2;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
using TileShape = typename Problem::BlockGemmShape;
using BlockWarps = typename TileShape::BlockWarps;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t WaveNum = BlockSize / WaveSize;
static constexpr index_t MPerBlock = TileShape::kM;
static constexpr index_t NPerBlock = TileShape::kN;
static constexpr index_t KPerBlock = TileShape::kK;
static constexpr index_t MWarps = BlockWarps::at(I0);
static constexpr index_t NWarps = BlockWarps::at(I1);
static_assert(WaveNum == MWarps * NWarps, "Block warps do not match block size");
static constexpr index_t MPerXdl = TileShape::WarpTile::at(I0);
static constexpr index_t NPerXdl = TileShape::WarpTile::at(I1);
static constexpr index_t KPerXdl = TileShape::WarpTile::at(I2);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static constexpr index_t K_Lane = get_warp_size() / 16; // 4
static constexpr index_t K_Thread = KPerXdl / K_Lane; // 32
public:
static constexpr index_t AK1 = DWORDx4 * APackedSize;
static constexpr index_t BK1 = DWORDx4 * BPackedSize;
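// e.g. for packed fp4 (PackedSize = 2) AK1 = 32 elements and for fp8 (PackedSize = 1)
// AK1 = 16 elements, i.e. one 16-byte (DWORDx4) chunk per lane along K in both cases.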
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
{
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher< //
ADataType,
BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< //
ADataType,
BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
}
template <typename TensorView>
CK_TILE_DEVICE static constexpr auto
MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view)
{
const auto& naive_desc = naive_view.get_tensor_descriptor();
constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
static_assert(ndims == 2, "only support 2D tensor");
const auto rows = naive_desc.get_length(number<0>{});
const auto cols = naive_desc.get_length(number<1>{});
constexpr index_t K2 = AK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
const index_t K0 = cols / (K1 * K2);
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
const index_t M0 = rows / M1;
const auto row_lens = make_tuple(M0, number<M1>{});
const auto desc_0 =
make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
const auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(M0),
make_xor_transform(make_tuple(number<M1>{}, number<K1>{})),
make_pass_through_transform(K0),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
const auto desc = transform_tensor_descriptor( //
desc_1,
make_tuple(make_merge_transform_v3_division_mod(row_lens),
make_merge_transform_v3_division_mod(col_lens)),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1));
return tensor_view<typename TensorView::buffer_view,
remove_cvref_t<decltype(desc)>,
TensorView::DstInMemOp>{naive_view.buf_, desc};
}
CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution()
{
constexpr index_t K2 = AK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
constexpr index_t M2 = WaveSize / K1; // 8
constexpr index_t M1 = BlockSize / WaveSize; // 4
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
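// Worked example (illustrative: fp4, BlockSize = 256, KPerBlock = 256): K2 = 32,
// K1 = 8, K0 = 1, M2 = 8, M1 = 4, M0 = MPerBlock / 32; each thread issues one
// dwordx4 (16-byte) load of 32 packed fp4 elements along K.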
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>, // M0,K0,K2
sequence<0, 0, 2>>{});
}
CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor()
{
constexpr index_t K2 = AK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
constexpr index_t M3 = 4; // so that we can use imm offset to load lds
constexpr index_t M2 = WaveSize / K1 / M3; // 2
constexpr index_t M1 = MPerXdl / (M2 * M3); // 2
constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!");
constexpr index_t Pad = 4 * K2; // 4 * 32
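// Worked example (illustrative: fp4, KPerBlock = 256): K2 = 32, K1 = 8, K0 = 1,
// M3 = 4, M2 = 2, M1 = 2, M0 = MPerBlock / 16, Pad = 128 elements; the pad is
// presumably inserted between M1 slices to avoid LDS bank conflicts on the read path.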
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<M0>{},
number<K0>{},
number<M1>{},
number<M2>{},
number<M3>{},
number<K1>{},
number<K2>{}),
make_tuple(number<K0*(M1 * (M2 * M3 * K1 * K2) + (M1 - 1) * Pad)>{},
number<M1*(M2 * M3 * K1 * K2) + (M1 - 1) * Pad>{},
number<M2 * M3 * K1 * K2 + Pad>{},
number<M3 * K1 * K2>{},
number<K1 * K2>{},
number<K2>{},
number<1>{}),
number<K2>{},
number<1>{});
constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(K0),
make_pass_through_transform(M1),
make_pass_through_transform(M2),
make_xor_transform(make_tuple(number<M3>{}, number<K1>{})),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4, 5>{},
sequence<6>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4, 5>{},
sequence<6>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<M0>{}, number<M1>{}, number<M2>{}, number<M3>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
make_tuple(sequence<0, 2, 3, 4>{}, sequence<1, 5, 6>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// return a_lds_block_desc_permuted;
return a_lds_block_desc;
}
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution()
{
static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
if constexpr(K_Thread == AK1)
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<NWarps>,
tuple<sequence<MWarps, MXdlPack, MPerXdl>, sequence<K_Lane, AK1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 2>>,
sequence<2>,
sequence<1>>{});
else
return make_static_tile_distribution(tile_distribution_encoding< //
sequence<NWarps>,
tuple<sequence<MWarps, MXdlPack, MPerXdl>,
sequence<K_Thread / AK1, K_Lane, AK1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 2>,
sequence<0, 2>>{});
}
// TODO: also create MakeMX_BAsyncLoadDramDescriptor, MakeMX_BDramTileDistribution and
// MakeMX_BLdsBlockDescriptor for non-flat B, to replace the flat-B versions below.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution()
{
constexpr index_t K1 = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t K0 = KWavePerBlk;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
if constexpr(BK1 == K_Thread)
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWarps, NXdlPack>, // 4 2
sequence<K0, K1, BK1 / BPackedSize>>, // 1 64 32
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 0>, sequence<1>>,
sequence<2>,
sequence<2>>{});
else
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWarps, NXdlPack>, // 4 2
sequence<K_Thread / BK1, K0, K1, BK1 / BPackedSize>>, // 2 1 64 16
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 1>, sequence<2>>,
sequence<2, 2>,
sequence<0, 3>>{});
}
template <typename WindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp)
{
constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1);
constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp;
constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp;
static_assert(std::decay_t<decltype(window_tmp)>::get_num_of_dimension() == 2);
auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view();
const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile;
auto&& byte_tensor_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(
flat_n, flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{})),
make_tuple(make_pass_through_transform(flat_n),
make_merge_transform_v3_division_mod(make_tuple(
flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
auto&& byte_tensor_view =
make_tensor_view<address_space_enum::global>(byte_ptr, byte_tensor_desc);
auto&& origin_tmp = window_tmp.get_window_origin();
return make_tile_window(
byte_tensor_view,
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp / BPackedSize>{}),
{origin_tmp[0], origin_tmp[1] / BPackedSize},
MakeMX_BFlatBytesDramTileDistribution());
}
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
{
constexpr index_t M_Lanes = TileShape::WarpTile::at(I0);
constexpr index_t K_Lanes = 64 / M_Lanes;
// Y dimension (M) decomposition
constexpr index_t Y2 = M_Lanes;
constexpr index_t Y1 = MWarps;
constexpr index_t Y0 = MPerBlock / (MXdlPack * Y1 * Y2);
// X dimension (K) decomposition
constexpr index_t X0 = K_Lanes;
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>, // repeat NWarps
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<1, 0>, sequence<0, 2>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution()
{
constexpr index_t N_Lanes = TileShape::WarpTile::at(I1);
constexpr index_t K_Lanes = 64 / N_Lanes;
// Y dimension (N) decomposition
constexpr index_t Y2 = N_Lanes;
constexpr index_t Y1 = NWarps;
constexpr index_t Y0 = NPerBlock / (NXdlPack * Y1 * Y2);
// X dimension (K) decomposition
constexpr index_t X0 = K_Lanes;
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps>, // repeat over MWarps
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>, // repeat over NWarps
tuple<sequence<MWarps, MPerXdl>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});
}
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps>, // repeat over MWarps
tuple<sequence<NWarps, NPerXdl>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<0, 1>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor().get_element_space_size() /
APackedSize;
}
// TODO: add smem size for B
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
// TODO: add B smem size
return GetSmemSizeA();
}
};
} // namespace ck_tile