remove some old files

This commit is contained in:
Sami Remes
2026-02-06 18:37:34 +00:00
parent 457474ed90
commit c7298e57c0
4 changed files with 0 additions and 2638 deletions

View File

@@ -1,599 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
namespace ck_tile {
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
// Block-level universal GEMM where both operands live in shared memory:
// A and B are read from LDS block windows into per-warp register tiles and
// C accumulates in a block-distributed register tensor.
//
// Template parameters:
//   Problem_     - pipeline problem descriptor (data types, block shape,
//                  block size, scheduler).
//   Policy_      - policy that selects the WarpGemm and the MWarp/NWarp
//                  warp layout.
//   UnaryOpSize_ - forwarded to load_int4_tile<> for the LDS -> register
//                  operand loads.
template <typename Problem_,
typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
index_t UnaryOpSize_ = 8>
struct BlockUniversalGemmAsBsCr
{
private:
// TODO: This should be in Policy - UniversalGemmPolicyBase ?
// Compile-time traits derived from the Problem/Policy pair: warp-gemm
// selection, warp layout, and per-warp iteration counts along M/N/K.
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
// config = (WarpGemm, MWarp, NWarp) as chosen by the policy.
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
using I0 = number<0>;
using I1 = number<1>;
// The policy-chosen warp layout and warp tile must agree with the
// shapes declared in BlockGemmShape.
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
"Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
"Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
"Error! WarpGemm's M is not consisten with BlockGemmShape!");
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
"Error! WarpGemm's N is not consisten with BlockGemmShape!");
// Number of warp-gemm iterations each warp performs per block tile.
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
"Error! Warps should cover all Block tile!");
static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
"Error! Warps should cover all Block tile!");
static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM;
static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
// Controls how many MAC clusters (MFMA blocks) we have per wave
// Ie if
// InterWaveSchedulingMacClusters = 1;
// KPerBlock == 32
// WarpGemm::kK = 8
// Then we would group all 4 WarpGemms into single MAC cluster.
// But if we would set InterWaveSchedulingMacClusters = 2, then we would
// split those 4 warp gemms into two groups.
static constexpr index_t InterWaveSchedulingMacClusters = 1;
// should be at least equal to: WarpGemm::Impl::kABKPerLane
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
};
public:
using Traits = GemmTraits_<Problem_, Policy_>;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// When one operand is a packed 4-bit type (pk_int4 / pk_fp4 raw), its
// register tile is materialized in the OTHER operand's type: the
// load_int4_tile<> calls below convert while loading from LDS.
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_raw_t>,
ADataType,
BDataType>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
static constexpr auto Scheduler = Traits::Scheduler;
// Per-warp distributions/tensors taken from the selected WarpGemm.
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
// Y-dimension lengths / zero-index sequences used for the sliced
// thread-data copies in the MAC loops below.
static constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using I0 = number<0>;
using I1 = number<1>;
// Builds the tile-distribution encoding for the A operand register tile.
// Under the Interwave scheduler the K extent is shortened to a single MAC
// cluster's worth of warp-gemm iterations (KIterInterwave); otherwise the
// full KIterPerWarp is used.
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, KIterSeq>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
// B-operand counterpart of MakeABlockDistributionEncode(); N takes the
// role of M and the warp dimensions are swapped accordingly.
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, KIterSeq>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
// Scheduler-specific GEMM implementation. Only the specializations below
// are usable; the primary template is intentionally empty.
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl
{
};
// Default scheduler: load both operand tiles from LDS once, then run the
// full K/M/N static loop nest of warp GEMMs. Note this specialization has
// no LocalPrefetch(); loads happen inside operator().
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Default, GemmTraits>
{
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
// Per-warp operand register tiles filled from the LDS block windows.
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
// C += A * B
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor-
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
};
// Intrawave scheduler: operand loads are exposed separately through
// LocalPrefetch() so the pipeline can overlap them with other work;
// operator() only runs the MAC loop nest on the already-prefetched tiles
// (its window arguments are intentionally unused).
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
{
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
// Loads the full A/B block tiles from LDS into the member register tiles.
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
}
// C += A * B
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow&,
const BSmemBlockWindow&,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
};
// Interwave scheduler: K is split into KRepeat MAC clusters of
// KInnerLoopIter warp-gemm iterations each. LocalPrefetch<KIdx>() loads
// just one cluster's K-slice; operator() interleaves clusters across waves
// using explicit s_barrier / sched_barrier / s_setprio intrinsics.
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Interwave, GemmTraits>
{
static constexpr index_t KPerThread = GemmTraits::KPerThread;
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
static constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
static constexpr auto ALdsTileDistr =
make_static_tile_distribution(MakeABlockDistributionEncode());
static constexpr auto BLdsTileDistr =
make_static_tile_distribution(MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
// Loads the KIdx-th K-slice (KPerInnerLoop wide) of the A/B block tiles
// from LDS. When a transposed load is requested, the window shape and
// offset swap their K/M (or K/N) dimensions accordingly.
template <index_t KIdx,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
constexpr auto a_lds_load_distr = [&]() {
if constexpr(ALoadTranspose)
return make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(MakeABlockDistributionEncode()),
ADataType>::TransposedDstrEncode{});
else
return make_static_tile_distribution(MakeABlockDistributionEncode());
}();
constexpr auto b_lds_load_distr = [&]() {
if constexpr(BLoadTranspose)
return make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(MakeBBlockDistributionEncode()),
BDataType>::TransposedDstrEncode{});
else
return make_static_tile_distribution(MakeBBlockDistributionEncode());
}();
constexpr auto a_lds_shape = []() {
if constexpr(ALoadTranspose)
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::MPerBlock>{});
else
return make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{});
}();
constexpr auto b_lds_shape = []() {
if constexpr(BLoadTranspose)
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::NPerBlock>{});
else
return make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{});
}();
constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
constexpr auto a_offset =
ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
constexpr auto b_offset =
BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
auto a_lds_gemm_window = make_tile_window(
a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_lds_gemm_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_lds_gemm_window);
}
// C += A * B
// NOTE: the barrier/setprio intrinsic placement below is carefully tuned;
// do not reorder statements in this loop nest.
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
// hot loop:
static_for<0, KRepeat, 1>{}([&](auto kIter) {
LocalPrefetch<kIter.value>(a_block_window, b_block_window, a_load_tr, b_load_tr);
__builtin_amdgcn_sched_barrier(0);
// NOTE: Synchronize threads in a workgroup at the start of each MAC
// cluster, but except the first, as we can shorten non-MAC cluster a bit
// and there's no observable negative impact. The desired effect is waves in
// a workgroup executing MAC in sync. This avoids some out-of-sync waves
// hijacking MAC resource from other workgroups and reducing the chance of
// latency hiding by waiting for the rest of the workgroup at the eventual
// sync point.
if constexpr(kIter.value != 0 || KRepeat == 1)
{
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kInnerIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kInnerIter>{},
b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor-
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts It is
// performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(kIter.value == KRepeat - 1 &&
kInnerIter.value == KInnerLoopIter - 1 &&
mIter.value == MIterPerWarp - 1 &&
nIter.value == NIterPerWarp - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// Raise wave priority for the remainder of the MAC cluster
// after the first warp-gemm has been issued.
if constexpr(kInnerIter.value == 0 && mIter.value == 0 &&
nIter.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
});
// Restore normal wave priority between MAC clusters.
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
});
}
};
public:
// Builds the (uninitialized) C block register tensor: the warp-gemm C
// distribution embedded into the MIter/NIter x MWarp/NWarp outer encoding.
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// Forwards to the scheduler-specific implementation's LocalPrefetch.
// NOTE(review): only the Intrawave specialization declares a LocalPrefetch
// matching this argument list; the Default specialization has none and the
// Interwave one requires an explicit KIdx template argument — presumably
// this wrapper is only instantiated for Intrawave. Verify against callers.
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}
// C += A * B
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
}
// C = A * B
// Convenience overload: allocates a fresh C block tile (contents start
// uninitialized; the scheduler impl accumulates into it) and returns it.
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
auto c_block_tensor = MakeCBlockTile();
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
return c_block_tensor;
}
private:
// Scheduler-selected implementation holding the operand register tiles.
BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
};
} // namespace ck_tile

View File

@@ -1,374 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/mx_pipeline_ag_bg_cr_v1_policy.hpp"
namespace ck_tile {
template <typename Problem, typename PipelinePolicy = MXGemmPipelineAgBgCrPolicy<Problem>>
struct MXGemmPipelineAgBgCrV1
{
// Problem-supplied element types, layouts and block shape.
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
// Compute in A's type; the assert below guarantees A is not narrower than B.
using ComputeType = ADataType;
static_assert(sizeof(ADataType) >= sizeof(BDataType));
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
// Single-operand tuples (this pipeline handles one A and one B tensor).
using AsDataType = ck_tile::tuple<ADataType>;
using BsDataType = ck_tile::tuple<BDataType>;
using AsLayout = ck_tile::tuple<ALayout>;
using BsLayout = ck_tile::tuple<BLayout>;
// No element-wise transform is applied to either operand.
using AElementWise = element_wise::PassThrough;
using BElementWise = element_wise::PassThrough;
// Elements per storage unit (e.g. 2 for packed 4-bit types).
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
using BlockFlatmm =
remove_cvref_t<decltype(PipelinePolicy::GetBlockFlatmm())>;
// config = (WarpGemm, MWarp, NWarp) chosen by the block policy.
static constexpr auto config =
BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t NumWaveGroups = BlockSize / WaveSize;
static constexpr bool UsePersistentKernel = true;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
// Warp-gemm iterations per warp along each block dimension.
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
// XDL packing factors from the problem; iterations grouped per pack.
static constexpr index_t MXdlPack = Problem::MXdlPack;
static constexpr index_t NXdlPack = Problem::NXdlPack;
static constexpr index_t KXdlPack = Problem::KXdlPack;
static constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
static constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
static constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
/// @brief Whether the block-scope main loop runs at all.
/// @param num_loop total number of main-loop iterations
/// @return true when there is at least one iteration to execute
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
    // Any positive iteration count means the hot loop body is entered.
    return 0 < num_loop;
}
/// @brief Tail schedule selector for the block loop.
/// This pipeline variant only ever emits the Full tail, regardless of the
/// runtime loop count (hence the unused parameter).
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t /* num_loop */)
{
    constexpr auto kTail = TailNumber::Full;
    return kTail;
}
/// @brief Invokes @p f with the compile-time hot-loop flag and the (fixed)
/// Full tail number. The runtime flags are accepted for interface
/// compatibility but ignored, since this pipeline only supports TailNumber::Full.
template <bool HasHotLoop, typename Callable>
CK_TILE_HOST_DEVICE static auto TailHandler(Callable&& f, bool /* has_hot_loop */, TailNumber /* tail_num */)
{
    constexpr auto tail = constant<TailNumber::Full>{};
    return f(bool_constant<HasHotLoop>{}, tail);
}
/// @brief Total LDS footprint (in bytes) required by this pipeline,
/// as computed by the policy.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    const index_t smem_bytes = PipelinePolicy::GetSmemSize();
    return smem_bytes;
}
/// @brief Vector size used for A loads: the number of A elements packed
/// per storage unit (identical to APackedSize).
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA()
{
    return numeric_traits<ADataType>::PackedSize;
}
/// @brief Vector size used for B loads: the number of B elements packed
/// per storage unit (identical to BPackedSize).
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
{
    return numeric_traits<BDataType>::PackedSize;
}
// This pipeline does not use a pre-shuffled B operand.
static constexpr bool Preshuffle = false;
// Runs the pipeline body (Run_) and repacks its per-(mIter, nIter) warp
// accumulator tensors into a single block-distributed C tile.
template <typename... Args>
CK_TILE_DEVICE auto operator()(Args&&... args) const
{
auto c_warp_tensors = Run_(std::forward<Args>(args)...);
// Block GEMM Acc register tile
using CWarpDstr = typename WG::CWarpDstr;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
auto c_block_tile = BlockFlatmm{}.MakeCBlockTile();
// Copy each warp tensor into its (mIter, nIter) slice of the block tile.
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensors(mIter)(nIter).get_thread_buffer());
});
});
return c_block_tile;
}
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp>
CK_TILE_DEVICE auto Run_(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* __restrict__ p_smem_ping,
void* __restrict__ p_smem_pong) const
{
using CWarpTensor = typename WG::CWarpTensor;
// A DRAM Window
auto a_dram_window =
make_tile_window(PipelinePolicy::MakeMX_AAsyncLoadDramDescriptor(
a_copy_dram_window_tmp.at(number<0>{}).get_bottom_tensor_view()),
a_copy_dram_window_tmp.at(number<0>{}).get_window_lengths(),
a_copy_dram_window_tmp.at(number<0>{}).get_window_origin(),
PipelinePolicy::MakeMX_ADramTileDistribution());
// B DRAM Window
auto b_dram_window =
make_tile_window(PipelinePolicy::MakeMX_BAsyncLoadDramDescriptor(
b_flat_dram_block_window_tmp.at(number<0>{}).get_bottom_tensor_view()),
b_flat_dram_block_window_tmp.at(number<0>{}).get_window_lengths(),
b_flat_dram_block_window_tmp.at(number<0>{}).get_window_origin(),
PipelinePolicy::MakeMX_BDramTileDistribution());
// Scale A DRAM Window
// With 1D K-only packing: window size is [MWarp * WG::kM, kKPerBlock / 32 / KXdlPack]
constexpr index_t ScaleBlockSize = 32;
auto scale_a_dram_window = make_tile_window(
scale_a_window.get_bottom_tensor_view(),
make_tuple(number<MWarp * WG::kM>{}, number<kKPerBlock / ScaleBlockSize / KXdlPack>{}),
scale_a_window.get_window_origin(),
PipelinePolicy::MakeMX_ScaleA_FlatDramTileDistribution());
const auto scale_a_dram_step_m = amd_wave_read_first_lane(
scale_a_dram_window.get_load_offset(tuple<number<MWarp * WG::kM>, number<0>>{}));
const auto scale_a_dram_step_k = amd_wave_read_first_lane(
scale_a_dram_window.get_load_offset(tuple<number<0>, number<kKPerBlock / ScaleBlockSize / KXdlPack>>{}));
// Scale B DRAM Window
// With 1D K-only packing and [K/32/4, N] layout: window size is [kKPerBlock / 32 / KXdlPack, NWarp * WG::kN]
auto scale_b_dram_window = make_tile_window(
scale_b_window.get_bottom_tensor_view(),
make_tuple(number<kKPerBlock / ScaleBlockSize / KXdlPack>{}, number<NWarp * WG::kN>{}),
scale_b_window.get_window_origin(),
PipelinePolicy::MakeMX_ScaleB_DramTileDistribution());
const auto scale_b_dram_step_k = amd_wave_read_first_lane(
scale_b_dram_window.get_load_offset(tuple<number<kKPerBlock / ScaleBlockSize / KXdlPack>, number<0>>{}));
const auto scale_b_dram_step_n = amd_wave_read_first_lane(
scale_b_dram_window.get_load_offset(tuple<number<0>, number<NWarp * WG::kN>>{}));
// LDS Views
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
constexpr index_t a_lds_bytes = PipelinePolicy::GetSmemSizeA();
BDataType* p_b_lds_ping = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem_ping) + a_lds_bytes);
BDataType* p_b_lds_pong = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem_pong) + a_lds_bytes);
constexpr auto a_lds_block_desc = PipelinePolicy::MakeMX_ALdsBlockDescriptor();
constexpr auto b_lds_block_desc = PipelinePolicy::MakeMX_BLdsBlockDescriptor();
auto a_lds_block_ping = make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
auto a_lds_block_pong = make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
auto b_lds_block_ping = make_tensor_view<address_space_enum::lds>(p_b_lds_ping, b_lds_block_desc);
auto b_lds_block_pong = make_tensor_view<address_space_enum::lds>(p_b_lds_pong, b_lds_block_desc);
// Store Windows (for Async Copy)
auto a_store_lds_window_ping = make_tile_window(a_lds_block_ping, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_store_lds_window_pong = make_tile_window(a_lds_block_pong, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto b_store_lds_window_ping = make_tile_window(b_lds_block_ping, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto b_store_lds_window_pong = make_tile_window(b_lds_block_pong, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Load Windows (for Warp Load)
auto a_warp_window_ping = make_tile_window(a_lds_block_ping, make_tuple(number<MWarp * WG::kM>{}, number<WG::kK>{}), {0, 0}, PipelinePolicy::MakeMX_ALDS_TileDistribution());
auto a_warp_window_pong = make_tile_window(a_lds_block_pong, make_tuple(number<MWarp * WG::kM>{}, number<WG::kK>{}), {0, 0}, PipelinePolicy::MakeMX_ALDS_TileDistribution());
auto b_warp_window_ping = make_tile_window(b_lds_block_ping, make_tuple(number<NWarp * WG::kN>{}, number<WG::kK>{}), {0, 0}, PipelinePolicy::MakeMX_BLDS_TileDistribution());
auto b_warp_window_pong = make_tile_window(b_lds_block_pong, make_tuple(number<NWarp * WG::kN>{}, number<WG::kK>{}), {0, 0}, PipelinePolicy::MakeMX_BLDS_TileDistribution());
// Register Tiles
statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp> c_warp_tensors;
// Initialize C
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
clear_tile(c_warp_tensors(mIter)(nIter));
});
});
// Scale Tiles
// With 1D K-only packing: one scale tile per M/N iter, indexed by K packed iter
// K dimension: each K iter processes WG::kK elements, each int32 has KXdlPack scales covering KXdlPack*32 elements
// So each KIterPerWarp needs KIterPerWarp/(KXdlPack) packed scale elements
constexpr index_t ScaleKPackedPerIter = (KIterPerWarp * WG::kK) / (32 * KXdlPack);
using ScaleATileType = statically_indexed_array<statically_indexed_array<decltype(load_tile_with_offset(scale_a_dram_window, tuple<number<0>, number<0>>{})), ScaleKPackedPerIter>, MIterPerWarp>;
using ScaleBTileType = statically_indexed_array<statically_indexed_array<decltype(load_tile_with_offset(scale_b_dram_window, tuple<number<0>, number<0>>{})), ScaleKPackedPerIter>, NIterPerWarp>;
ScaleATileType scale_a_tile_ping, scale_a_tile_pong;
ScaleBTileType scale_b_tile_ping, scale_b_tile_pong;
auto async_load_tile_ = [](auto lds, auto dram) {
async_load_tile(lds, dram, number<-1>{}, true_type{}, true_type{});
};
auto load_scales_ = [&](auto& scale_a, auto& scale_b) {
// Load scales for each M/N iteration
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
scale_a(mIter)(kPacked) = load_tile_with_offset(
scale_a_dram_window, mIter * scale_a_dram_step_m + kPacked * scale_a_dram_step_k);
});
});
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
// Scale B is [K/32/4, N], so K is first dimension
scale_b(nIter)(kPacked) = load_tile_with_offset(
scale_b_dram_window, kPacked * scale_b_dram_step_k + nIter * scale_b_dram_step_n);
});
});
move_tile_window(scale_a_dram_window, {0, kKPerBlock / ScaleBlockSize / KXdlPack});
move_tile_window(scale_b_dram_window, {kKPerBlock / ScaleBlockSize / KXdlPack, 0});
};
// Helper for Main Loop
auto warp_gemm_loop = [&](auto& a_warp_window, auto& b_warp_window, auto& scale_a, auto& scale_b) {
// Define register tiles types for double buffering
using AValType = decltype(load_tile_with_offset(a_warp_window, tuple<number<0>, number<0>>{}));
using BValType = decltype(load_tile_with_offset(b_warp_window, tuple<number<0>, number<0>>{}));
statically_indexed_array<statically_indexed_array<AValType, MIterPerWarp>, 2> a_vals;
statically_indexed_array<statically_indexed_array<BValType, NIterPerWarp>, 2> b_vals;
auto load_k = [&]<typename K, typename Buf>(const K&, const Buf& buf_idx) {
static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) {
a_vals(buf_idx)(m_iter) = load_tile_with_offset(
a_warp_window,
tuple<number<m_iter * MWarp * WG::kM>, number<K{} * WG::kK>>{});
});
static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
b_vals(buf_idx)(n_iter) = load_tile_with_offset(
b_warp_window,
tuple<number<n_iter * NWarp * WG::kN>, number<K{} * WG::kK>>{});
});
};
// Prologue: Load K=0
load_k(number<0>{}, number<0>{});
static_for<0, KIterPerWarp, 1>{}([&](auto k_iter) {
constexpr auto cur_buf = k_iter % 2;
constexpr auto nxt_buf = (k_iter + 1) % 2;
// Prefetch K+1
if constexpr(k_iter < KIterPerWarp - 1) {
load_k(number<k_iter + 1>{}, number<nxt_buf>{});
}
// Map k_iter to packed scale index
// Each k_iter processes WG::kK elements
// Each packed int32 contains KXdlPack scales, each covering 32 elements
// So we need k_iter * WG::kK / (32 * KXdlPack) to get the packed index
// and k_iter * WG::kK / 32 % KXdlPack to get which scale within the pack
constexpr index_t kScalePacked = (k_iter * WG::kK) / (32 * KXdlPack);
constexpr index_t kScaleInPack = ((k_iter * WG::kK) / 32) % KXdlPack;
static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) {
// OpSel selects which of the KXdlPack packed e8m0 values to use
constexpr auto OpSelA = kScaleInPack;
static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
// OpSel selects which of the KXdlPack packed e8m0 values to use
constexpr auto OpSelB = kScaleInPack;
WG{}.template operator()<OpSelA, OpSelB>(
c_warp_tensors(m_iter)(n_iter),
bit_cast<typename WG::AWarpTensor>(a_vals(number<cur_buf>{})(m_iter)),
bit_cast<typename WG::BWarpTensor>(b_vals(number<cur_buf>{})(n_iter)),
scale_a(m_iter)(number<kScalePacked>{}).get_thread_buffer()[0],
scale_b(n_iter)(number<kScalePacked>{}).get_thread_buffer()[0]);
});
});
});
};
// Prologue: Load first block
async_load_tile_(a_store_lds_window_ping, a_dram_window);
async_load_tile_(b_store_lds_window_ping, b_dram_window);
// Load Scales (Ping - Iter 0)
load_scales_(scale_a_tile_ping, scale_b_tile_ping);
// Load Scales (Pong - Iter 1)
if (num_loop > 1) {
load_scales_(scale_a_tile_pong, scale_b_tile_pong);
}
// Move DRAM windows
move_tile_window(a_dram_window, {0, kKPerBlock});
move_tile_window(b_dram_window, {0, kKPerBlock});
// Scale windows already moved in load_scales_
// Main Loop
index_t i = 0;
do {
// Wait for LDS load
s_waitcnt<0>();
block_sync_lds();
// Trigger next load (Ping-Pong)
if (i < num_loop - 1) {
if (i % 2 == 0) {
async_load_tile_(a_store_lds_window_pong, a_dram_window);
async_load_tile_(b_store_lds_window_pong, b_dram_window);
} else {
async_load_tile_(a_store_lds_window_ping, a_dram_window);
async_load_tile_(b_store_lds_window_ping, b_dram_window);
}
move_tile_window(a_dram_window, {0, kKPerBlock});
move_tile_window(b_dram_window, {0, kKPerBlock});
}
// Compute
if (i % 2 == 0) {
warp_gemm_loop(a_warp_window_ping, b_warp_window_ping, scale_a_tile_ping, scale_b_tile_ping);
// Load next scales (Ping - Iter i+2)
if (i + 2 < num_loop) {
load_scales_(scale_a_tile_ping, scale_b_tile_ping);
}
} else {
warp_gemm_loop(a_warp_window_pong, b_warp_window_pong, scale_a_tile_pong, scale_b_tile_pong);
// Load next scales (Pong - Iter i+2)
if (i + 2 < num_loop) {
load_scales_(scale_a_tile_pong, scale_b_tile_pong);
}
}
i++;
} while (i < num_loop);
return c_warp_tensors;
}
};
} // namespace ck_tile

// ---------------------------------------------------------------------------
// NOTE: diff-viewer residue removed here ("View File" / "@@ -1,548 +0,0 @@").
// Everything below this marker came from a second deleted header file:
// the MX GEMM pipeline Ag/Bg/Cr policy.
// ---------------------------------------------------------------------------
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
namespace ck_tile {
// -----------------------------------------------------------------------------
// Policy for an MX (micro-scaled, e.g. MXFP4/MXFP8) GEMM pipeline with
// A loaded from global to LDS, B loaded from global to LDS (or as pre-shuffled
// "flat" bytes), and C accumulated in registers.
//
// It supplies:
//   * async (direct-to-LDS) DRAM descriptors for A and B,
//   * thread tile distributions for the DRAM loads,
//   * padded + xor-swizzled LDS block descriptors,
//   * warp-level LDS read distributions for the MFMA inputs,
//   * tile distributions for the packed e8m0 scale tensors of A and B,
//   * LDS size queries.
//
// Only K-direction packing of the e8m0 scales is used (KXdlPack = 4 scales
// per int32); the M/N packing factors are fixed to 1.
// -----------------------------------------------------------------------------
template <typename Problem>
struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
// Compile-time index short-hands.
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
// Bytes fetched per DRAM "pack" (one swizzle group) and per 16-byte vector.
static constexpr index_t kDramLoadPackBytes = 128;
static constexpr index_t DWORDx4 = 16;
static constexpr int MXdlPack = 1; // No M packing
static constexpr int NXdlPack = 1; // No N packing
static constexpr int KXdlPack = 4; // Pack 4 consecutive e8m0 scales in K = 4 bytes = 1 int32
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
// Elements per addressable byte-slot (e.g. 2 for nibble-packed fp4, 1 for fp8).
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
// Only K-contiguous (row-major) A is supported by the descriptors below.
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
using TileShape = typename Problem::BlockGemmShape;
using BlockWarps = typename TileShape::BlockWarps;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t WaveNum = BlockSize / WaveSize;
static constexpr index_t MPerBlock = TileShape::kM;
static constexpr index_t NPerBlock = TileShape::kN;
static constexpr index_t KPerBlock = TileShape::kK;
static constexpr index_t MWarps = BlockWarps::at(I0);
static constexpr index_t NWarps = BlockWarps::at(I1);
static_assert(WaveNum == MWarps * NWarps, "Block warps do not match block size");
// Per-XDL (warp-gemm) tile sizes; only 16x16 MFMA tiles are supported here.
static constexpr index_t MPerXdl = TileShape::WarpTile::at(I0);
static constexpr index_t NPerXdl = TileShape::WarpTile::at(I1);
static constexpr index_t KPerXdl = TileShape::WarpTile::at(I2);
static_assert(MPerXdl == 16 && NPerXdl == 16);
// Lane split of a wave along K for a 16-wide MFMA, and the per-lane K extent.
// NOTE(review): inline values below assume wave64 and KPerXdl = 128 — confirm.
static constexpr index_t K_Lane = get_warp_size() / 16; // 4
static constexpr index_t K_Thread = KPerXdl / K_Lane; // 32
public:
// K-vector length (in elements) of one 16-byte load of A / B.
static constexpr index_t AK1 = DWORDx4 * APackedSize;
static constexpr index_t BK1 = DWORDx4 * BPackedSize;
// Builds the block-level flatmm object (A in LDS, B in LDS, C in registers)
// with a warp-gemm dispatched from the problem's warp tile shape.
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
{
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher< //
ADataType,
BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< //
ADataType,
BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
}
// Re-describes a 2D (rows=M, cols=K) global view of A for async direct-to-LDS
// loads: K is split as (K0, K1, K2) and M as (M0, M1=4), an M1 x K1 xor swizzle
// is applied (so stores land at LDS addresses reachable with immediate
// offsets), and the dims are merged back into a swizzled 2D (M, K) view.
template <typename TensorView>
CK_TILE_DEVICE static constexpr auto
MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view)
{
const auto& naive_desc = naive_view.get_tensor_descriptor();
constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
static_assert(ndims == 2, "only support 2D tensor");
const auto rows = naive_desc.get_length(number<0>{});
const auto cols = naive_desc.get_length(number<1>{});
constexpr index_t K2 = AK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
const index_t K0 = cols / (K1 * K2);
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
const index_t M0 = rows / M1;
const auto row_lens = make_tuple(M0, number<M1>{});
const auto desc_0 =
make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
// Swizzle: xor the M1 sub-row index with the K1 vector index.
const auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(M0),
make_xor_transform(make_tuple(number<M1>{}, number<K1>{})),
make_pass_through_transform(K0),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
// Merge back to a plain-looking 2D (M, K) descriptor over the swizzled layout.
const auto desc = transform_tensor_descriptor( //
desc_1,
make_tuple(make_merge_transform_v3_division_mod(row_lens),
make_merge_transform_v3_division_mod(col_lens)),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1));
return tensor_view<typename TensorView::buffer_view,
remove_cvref_t<decltype(desc)>,
TensorView::DstInMemOp>{naive_view.buf_, desc};
}
// Thread mapping of the A DRAM tile over the block: each lane loads one
// K2-element vector; K1 lanes of a wave cover one 128-byte pack, the waves and
// remaining lanes tile M; M0 and K0 are per-thread repeats.
CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution()
{
constexpr index_t K2 = AK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
constexpr index_t M2 = WaveSize / K1; // 8
constexpr index_t M1 = BlockSize / WaveSize; // 4
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>, // M0,K0,K2
sequence<0, 0, 2>>{});
}
// Padded + xor-swizzled LDS layout for the A block tile, merged back into a
// 2D (MPerBlock, KPerBlock) descriptor. Pad bytes are inserted per M1 slice to
// avoid bank conflicts; the M3 x K1 xor matches the DRAM-side swizzle above.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor()
{
constexpr index_t K2 = AK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
constexpr index_t M3 = 4; // so that we can use imm offset to load lds
constexpr index_t M2 = WaveSize / K1 / M3; // 2
constexpr index_t M1 = MPerXdl / (M2 * M3); // 2
constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!");
constexpr index_t Pad = 4 * K2; // 4 * 32
// Strided naive descriptor with Pad elements inserted after each M1 slice.
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<M0>{},
number<K0>{},
number<M1>{},
number<M2>{},
number<M3>{},
number<K1>{},
number<K2>{}),
make_tuple(number<K0*(M1 * (M2 * M3 * K1 * K2) + (M1 - 1) * Pad)>{},
number<M1*(M2 * M3 * K1 * K2) + (M1 - 1) * Pad>{},
number<M2 * M3 * K1 * K2 + Pad>{},
number<M3 * K1 * K2>{},
number<K1 * K2>{},
number<K2>{},
number<1>{}),
number<K2>{},
number<1>{});
// Apply the same M3 x K1 xor swizzle the async store used.
constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(K0),
make_pass_through_transform(M1),
make_pass_through_transform(M2),
make_xor_transform(make_tuple(number<M3>{}, number<K1>{})),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4, 5>{},
sequence<6>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4, 5>{},
sequence<6>{}));
// Collapse to 2D (M, K) for the tile windows.
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<M0>{}, number<M1>{}, number<M2>{}, number<M3>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
make_tuple(sequence<0, 2, 3, 4>{}, sequence<1, 5, 6>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// return a_lds_block_desc_permuted;
return a_lds_block_desc;
}
// Distribution the warps use to read A back out of LDS for the MFMA.
// Two shapes: when one AK1 vector covers the whole per-lane K extent
// (K_Thread == AK1) a single vector read suffices; otherwise the per-lane K is
// split into K_Thread/AK1 repeats of AK1-wide vector reads.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution()
{
static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
if constexpr(K_Thread == AK1)
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<NWarps>,
tuple<sequence<MWarps, MXdlPack, MPerXdl>, sequence<K_Lane, AK1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 2>>,
sequence<2>,
sequence<1>>{});
else
return make_static_tile_distribution(tile_distribution_encoding< //
sequence<NWarps>,
tuple<sequence<MWarps, MXdlPack, MPerXdl>,
sequence<K_Thread / AK1, K_Lane, AK1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 2>, // K_Thread/AK1, AK1
sequence<0, 2>>{});
}
// B-side mirror of MakeMX_AAsyncLoadDramDescriptor: same (N1=4) x K1 xor
// swizzle for async direct-to-LDS loads of the (N, K) B tensor.
template <typename TensorView>
CK_TILE_DEVICE static constexpr auto
MakeMX_BAsyncLoadDramDescriptor(const TensorView& naive_view)
{
const auto& naive_desc = naive_view.get_tensor_descriptor();
constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
static_assert(ndims == 2, "only support 2D tensor");
const auto rows = naive_desc.get_length(number<0>{});
const auto cols = naive_desc.get_length(number<1>{});
constexpr index_t K2 = BK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
const index_t K0 = cols / (K1 * K2);
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
constexpr index_t N1 = 4; // so that we can use imm offset to load lds
const index_t N0 = rows / N1;
const auto row_lens = make_tuple(N0, number<N1>{});
const auto desc_0 =
make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
const auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(N0),
make_xor_transform(make_tuple(number<N1>{}, number<K1>{})),
make_pass_through_transform(K0),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
const auto desc = transform_tensor_descriptor( //
desc_1,
make_tuple(make_merge_transform_v3_division_mod(row_lens),
make_merge_transform_v3_division_mod(col_lens)),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return tensor_view<typename TensorView::buffer_view,
remove_cvref_t<decltype(desc)>,
TensorView::DstInMemOp>{naive_view.buf_, desc};
}
// B-side mirror of MakeMX_ADramTileDistribution: thread mapping of the B DRAM
// tile, with N taking the role of M.
CK_TILE_DEVICE static constexpr auto MakeMX_BDramTileDistribution()
{
// TODO: these could be replaced by the standard UniversalGEMM tile distributions??
constexpr index_t K2 = BK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * BPackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
constexpr index_t N2 = WaveSize / K1; // 8
constexpr index_t N1 = BlockSize / WaveSize; // 4
constexpr index_t N0 = NPerBlock / (N2 * N1);
static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>, // N0,K0,K2
sequence<0, 0, 2>>{});
}
// B-side mirror of MakeMX_ALdsBlockDescriptor: padded + xor-swizzled LDS
// layout for the B block tile, merged back into 2D (NPerBlock, KPerBlock).
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLdsBlockDescriptor()
{
constexpr index_t K2 = BK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
constexpr index_t N3 = 4; // so that we can use imm offset to load lds
constexpr index_t N2 = WaveSize / K1 / N3; // 2
constexpr index_t N1 = NPerXdl / (N2 * N3); // 2
constexpr index_t N0 = NPerBlock / (N1 * N2 * N3); // NPerBlock/16
static_assert(N0 * N1 * N2 * N3 == NPerBlock, "N0, N1, N2, N3 must cover whole NPerBlock!");
constexpr index_t Pad = 4 * K2; // 4 * 32
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<N0>{},
number<K0>{},
number<N1>{},
number<N2>{},
number<N3>{},
number<K1>{},
number<K2>{}),
make_tuple(number<K0*(N1 * (N2 * N3 * K1 * K2) + (N1 - 1) * Pad)>{},
number<N1*(N2 * N3 * K1 * K2) + (N1 - 1) * Pad>{},
number<N2 * N3 * K1 * K2 + Pad>{},
number<N3 * K1 * K2>{},
number<K1 * K2>{},
number<K2>{},
number<1>{}),
number<K2>{},
number<1>{});
constexpr auto b_lds_block_desc_1 = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(N0),
make_pass_through_transform(K0),
make_pass_through_transform(N1),
make_pass_through_transform(N2),
make_xor_transform(make_tuple(number<N3>{}, number<K1>{})),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4, 5>{},
sequence<6>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4, 5>{},
sequence<6>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<N0>{}, number<N1>{}, number<N2>{}, number<N3>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
make_tuple(sequence<0, 2, 3, 4>{}, sequence<1, 5, 6>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
// B-side mirror of MakeMX_ALDS_TileDistribution: warp read distribution of B
// from LDS, with the same one-vector vs. repeated-vector split on BK1.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLDS_TileDistribution()
{
static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
if constexpr(K_Thread == BK1)
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<MWarps>,
tuple<sequence<NWarps, NXdlPack, NPerXdl>, sequence<K_Lane, BK1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 2>>,
sequence<2>,
sequence<1>>{});
else
return make_static_tile_distribution(tile_distribution_encoding< //
sequence<MWarps>,
tuple<sequence<NWarps, NXdlPack, NPerXdl>,
sequence<K_Thread / BK1, K_Lane, BK1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 2>,
sequence<0, 2>>{});
}
// TODO: create also MakeMX_BAsyncLoadDramDescriptor, MakeMX_BDramTileDistribution MakeMX_BLdsBlockDescriptor for non-flat B
// to replace the below ones for flat B
// Byte-granularity thread mapping for loading pre-shuffled ("flat") B straight
// from DRAM: one wave per flat-N warp slice, each lane reading BK1/BPackedSize
// bytes per access.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution()
{
constexpr index_t K1 = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t K0 = KWavePerBlk;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
if constexpr(BK1 == K_Thread)
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWarps, NXdlPack>, // 4 2
sequence<K0, K1, BK1 / BPackedSize>>, // 1 64 32
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 0>, sequence<1>>,
sequence<2>,
sequence<2>>{});
else
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWarps, NXdlPack>, // 4 2
sequence<K_Thread / BK1, K0, K1, BK1 / BPackedSize>>, // 2 1 64 16
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 1>, sequence<2>>,
sequence<2, 2>,
sequence<0, 3>>{});
}
// Wraps an existing (flat_n, flat_k) element-typed B window into a uint8_t
// byte view with the matching flat-bytes distribution, dividing K coordinates
// by BPackedSize to convert element offsets to byte offsets.
template <typename WindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp)
{
// NOTE(review): named M_Warp_Tile but reads WarpTile::at(I1) (the N size);
// value-identical here since MPerXdl == NPerXdl == 16 is asserted — confirm intent.
constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1);
constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp;
constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp;
static_assert(std::decay_t<decltype(window_tmp)>::get_num_of_dimension() == 2);
auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view();
const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile;
auto&& byte_tensor_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(
flat_n, flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{})),
make_tuple(make_pass_through_transform(flat_n),
make_merge_transform_v3_division_mod(make_tuple(
flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
auto&& byte_tensor_view =
make_tensor_view<address_space_enum::global>(byte_ptr, byte_tensor_desc);
auto&& origin_tmp = window_tmp.get_window_origin();
return make_tile_window(
byte_tensor_view,
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp / BPackedSize>{}),
{origin_tmp[0], origin_tmp[1] / BPackedSize},
MakeMX_BFlatBytesDramTileDistribution());
}
// Distribution for the packed A-scale tensor laid out as [M, K/32/KXdlPack]
// int32 elements (each int32 holds KXdlPack e8m0 scales).
// NOTE(review): the literal 64 below assumes wave64 — confirm.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
{
// With 1D K-only packing: MXdlPack=1, so no complex M packing
// Simple 2D distribution for [M, K/32/KXdlPack] layout
constexpr index_t M_Lanes = TileShape::WarpTile::at(I0);
constexpr index_t K_Lanes = 64 / M_Lanes;
// Y dimension (M) decomposition - no packing factor
constexpr index_t Y2 = M_Lanes;
constexpr index_t Y1 = MWarps;
constexpr index_t Y0 = MPerBlock / (Y1 * Y2);
// X dimension (K) decomposition - each int32 contains KXdlPack scales
constexpr index_t X0 = K_Lanes;
constexpr index_t X1 = 1; // vec load of int32
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>, // repeat NWarps
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<1, 0>, sequence<0, 2>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// Distribution for the packed B-scale tensor laid out K-major as
// [K/32/KXdlPack, N] int32 elements (matches the reference layout).
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution()
{
// With 1D K-only packing and [K/32/4, N] layout to match reference
// Layout is [K, N] where K is packed int32
constexpr index_t N_Lanes = TileShape::WarpTile::at(I1);
constexpr index_t K_Lanes = 64 / N_Lanes;
// First tuple element: K dimension decomposition
constexpr index_t K0 = K_Lanes;
constexpr index_t K1 = 1; // vec load of int32
// Second tuple element: N dimension decomposition
constexpr index_t N2 = N_Lanes;
constexpr index_t N1 = NWarps;
constexpr index_t N0 = NPerBlock / (N1 * N2);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps>, // repeat MWarps
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
tuple<sequence<2, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
sequence<2, 1>,
sequence<1, 0>>{});
}
// Single-warp-tile ("flat") variant of the A-scale distribution:
// [MWarp*MPerXdl, K/32/KXdlPack] with one int32 per lane.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution()
{
// With 1D K-only packing: simpler distribution for [MWarp*MPerXdl, K/32/KXdlPack]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>, // repeat over NWarps
tuple<sequence<MWarps, MPerXdl>, // M dimension
sequence<K_Lane, 1>>, // K dimension (int32 vec load)
tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});
}
// Single-warp-tile ("flat") variant of the B-scale distribution:
// [K/32/KXdlPack, NWarp*NPerXdl] with one int32 per lane.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution()
{
// With 1D K-only packing and [K/32/4, N] layout: [K/32/KXdlPack, NWarp*NPerXdl]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps>, // repeat over MWarps
tuple<sequence<K_Lane, 1>, // K dimension (int32 vec load)
sequence<NWarps, NPerXdl>>, // N dimension
tuple<sequence<2, 1>, sequence<0, 1>>, // which direction
tuple<sequence<0, 1>, sequence<0, 0>>, // which index
// <repeat, vec_load>
sequence<1>,
sequence<2>>{});
}
// LDS bytes needed for the A block tile (element space scaled down by the
// per-byte packing factor).
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor().get_element_space_size() /
APackedSize;
}
// LDS bytes needed for the B block tile.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
return sizeof(BDataType) * MakeMX_BLdsBlockDescriptor().get_element_space_size() /
BPackedSize;
}
// Total LDS bytes for one pipeline buffer (A followed by B).
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return GetSmemSizeA() + GetSmemSizeB();
}
};
} // namespace ck_tile