mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
remove some old files
This commit is contained in:
@@ -1,599 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
|
||||
index_t UnaryOpSize_ = 8>
|
||||
struct BlockUniversalGemmAsBsCr
|
||||
{
|
||||
private:
|
||||
// TODO: This should be in Policy - UniversalGemmPolicyBase ?
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
struct GemmTraits_
|
||||
{
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
|
||||
"Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
|
||||
static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
|
||||
"Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
|
||||
"Error! WarpGemm's M is not consisten with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
|
||||
"Error! WarpGemm's N is not consisten with BlockGemmShape!");
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
|
||||
"Error! Warps should cover all Block tile!");
|
||||
static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
|
||||
"Error! Warps should cover all Block tile!");
|
||||
|
||||
static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM;
|
||||
static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
|
||||
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
|
||||
|
||||
// Controls how many MAC clusters (MFMA blocks) we have per wave
|
||||
// Ie if
|
||||
// InterWaveSchedulingMacClusters = 1;
|
||||
// KPerBlock == 32
|
||||
// WarpGemm::kK = 8
|
||||
// Then we would group all 4 WarpGemms into single MAC cluster.
|
||||
// But if we would set InterWaveSchedulingMacClusters = 2, then we would
|
||||
// split those 4 warp gemms into two groups.
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
|
||||
// should be at least equal to: WarpGemm::Impl::kABKPerLane
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
};
|
||||
|
||||
public:
// Resolved traits for this instantiation.
using Traits = GemmTraits_<Problem_, Policy_>;

using ADataType       = remove_cvref_t<typename Traits::ADataType>;
using BDataType       = remove_cvref_t<typename Traits::BDataType>;
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType       = remove_cvref_t<typename Traits::CDataType>;

// For packed-int4 A (and packed-int4/fp4 B) the register tile is stored in
// the *other* operand's type — load_int4_tile converts on the fly.
using ATypeToUse =
    std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
                                          std::is_same_v<BDataType, pk_fp4_raw_t>,
                                      ADataType,
                                      BDataType>;

using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;

// Warp-iteration counts per block tile (see GemmTraits_).
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;

static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;

static constexpr auto Scheduler = Traits::Scheduler;

// Warp-level distributions/tensors for the A, B and C operands.
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;

using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;

// Y-dimension lengths of each warp distribution, used for thread-data slicing.
static constexpr auto a_warp_y_lengths =
    to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto b_warp_y_lengths =
    to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto c_warp_y_lengths =
    to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

// All-zero Y starting indices for the slicing calls below.
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

// Elements packed per storage unit for A/B (e.g. 2 for 4-bit types).
static constexpr index_t APackedSize =
    ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
    ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;

using I0 = number<0>;
using I1 = number<1>;
|
||||
|
||||
// Builds the block-level tile-distribution encoding for the A operand.
// The outer encoding spreads MIterPerWarp x MWarp over M and the K iteration
// count over K (NWarp appears only as a replicated "R" dimension), then the
// warp-level A encoding is embedded into it.
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
    constexpr index_t KPerThread     = Traits::KPerThread;
    constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
    // K extent of one inner MAC-cluster loop; never smaller than a single
    // warp-gemm's per-thread K.
    constexpr index_t KPerInnerLoop =
        ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
    constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;

    // Interwave scheduling only distributes one inner-loop's worth of K
    // iterations; other schedulers take the full KIterPerWarp.
    using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
                                        sequence<KIterInterwave>,
                                        sequence<KIterPerWarp>>;

    constexpr auto a_block_outer_dstr_encoding =
        tile_distribution_encoding<sequence<NWarp>,
                                   tuple<sequence<MIterPerWarp, MWarp>, KIterSeq>,
                                   tuple<sequence<1, 0>>,
                                   tuple<sequence<1, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 0>>{};
    // Embed the per-warp A distribution into the block-level outer encoding.
    constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});

    return a_block_dstr_encode;
}
|
||||
|
||||
// Builds the block-level tile-distribution encoding for the B operand.
// Mirror image of MakeABlockDistributionEncode: NIterPerWarp x NWarp over N,
// K iterations over K, with MWarp as the replicated "R" dimension.
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
    constexpr index_t KPerThread     = Traits::KPerThread;
    constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
    // K extent of one inner MAC-cluster loop; never smaller than a single
    // warp-gemm's per-thread K.
    constexpr index_t KPerInnerLoop =
        ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
    constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;

    // Interwave scheduling only distributes one inner-loop's worth of K
    // iterations; other schedulers take the full KIterPerWarp.
    using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
                                        sequence<KIterInterwave>,
                                        sequence<KIterPerWarp>>;

    constexpr auto b_block_outer_dstr_encoding =
        tile_distribution_encoding<sequence<MWarp>,
                                   tuple<sequence<NIterPerWarp, NWarp>, KIterSeq>,
                                   tuple<sequence<0, 1>>,
                                   tuple<sequence<0, 1>>,
                                   sequence<1, 2>,
                                   sequence<0, 0>>{};
    // Embed the per-warp B distribution into the block-level outer encoding.
    constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});

    return b_block_dstr_encode;
}
|
||||
|
||||
// Primary template — intentionally empty; one explicit specialization per
// GemmPipelineScheduler value follows below.
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl
{
};
|
||||
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Default, GemmTraits>
|
||||
{
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
|
||||
a_block_window);
|
||||
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
|
||||
b_block_window);
|
||||
// hot loop:
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
|
||||
{
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
|
||||
a_block_window);
|
||||
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
|
||||
b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow&,
|
||||
const BSmemBlockWindow&,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Interwave scheduler: K is split into KRepeat MAC clusters of KPerInnerLoop
// each. Each cluster prefetches its own K-slice of A/B from LDS, and
// s_barrier/sched_barrier/s_setprio intrinsics keep the waves of a workgroup
// executing their MAC phases in lock-step. The exact statement order is
// load-bearing for the scheduling behavior — do not reorder.
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Interwave, GemmTraits>
{
    static constexpr index_t KPerThread     = GemmTraits::KPerThread;
    static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
    // K extent handled by one MAC cluster (at least one warp-gemm's worth).
    static constexpr index_t KPerInnerLoop =
        ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
    // Number of MAC clusters, and warp-gemm iterations inside each cluster.
    static constexpr index_t KRepeat        = KPerThread / KPerInnerLoop;
    static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;

    static constexpr auto ALdsTileDistr =
        make_static_tile_distribution(MakeABlockDistributionEncode());
    static constexpr auto BLdsTileDistr =
        make_static_tile_distribution(MakeBBlockDistributionEncode());

    using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
    using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));

    // Register tiles holding one MAC cluster's K-slice of A/B.
    ALdsTile a_warp_tile_;
    BLdsTile b_warp_tile_;

    // Load the KIdx-th K-slice (of KPerInnerLoop) of A/B from the block
    // windows into the register tiles. When *LoadTranspose is set, a
    // transposed distribution and a K-major window shape are used instead.
    template <index_t KIdx,
              typename ASmemBlockWindow,
              typename BSmemBlockWindow,
              bool ALoadTranspose = false,
              bool BLoadTranspose = false>
    CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
                                      const BSmemBlockWindow& b_block_window,
                                      bool_constant<ALoadTranspose> = {},
                                      bool_constant<BLoadTranspose> = {})
    {
        // Pick the (possibly transposed) load distribution for A.
        constexpr auto a_lds_load_distr = [&]() {
            if constexpr(ALoadTranspose)
                return make_static_tile_distribution(typename InputTileDistributionTraits<
                                                     decltype(MakeABlockDistributionEncode()),
                                                     ADataType>::TransposedDstrEncode{});
            else
                return make_static_tile_distribution(MakeABlockDistributionEncode());
        }();
        // Pick the (possibly transposed) load distribution for B.
        constexpr auto b_lds_load_distr = [&]() {
            if constexpr(BLoadTranspose)
                return make_static_tile_distribution(typename InputTileDistributionTraits<
                                                     decltype(MakeBBlockDistributionEncode()),
                                                     BDataType>::TransposedDstrEncode{});
            else
                return make_static_tile_distribution(MakeBBlockDistributionEncode());
        }();
        // Window shapes: [K, M]/[K, N] when transposed, [M, K]/[N, K] otherwise.
        constexpr auto a_lds_shape = []() {
            if constexpr(ALoadTranspose)
                return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::MPerBlock>{});
            else
                return make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{});
        }();
        constexpr auto b_lds_shape = []() {
            if constexpr(BLoadTranspose)
                return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::NPerBlock>{});
            else
                return make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{});
        }();
        // Start offset of this cluster's K-slice, placed on the K axis of the
        // (possibly transposed) window.
        constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
        constexpr auto a_offset =
            ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
        constexpr auto b_offset =
            BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};

        auto a_lds_gemm_window = make_tile_window(
            a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
        auto b_lds_gemm_window = make_tile_window(
            b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);

        load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
                                                                            a_lds_gemm_window);
        load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
                                                                            b_lds_gemm_window);
    }

    // C += A * B
    // Per MAC cluster: prefetch that cluster's K-slice, synchronize the
    // workgroup, raise instruction priority during MAC, and lower it again
    // after the cluster completes.
    template <typename CBlockTensor,
              typename ASmemBlockWindow,
              typename BSmemBlockWindow,
              bool ALoadTranspose = false,
              bool BLoadTranspose = false>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ASmemBlockWindow& a_block_window,
                                   const BSmemBlockWindow& b_block_window,
                                   bool_constant<ALoadTranspose> a_load_tr = {},
                                   bool_constant<BLoadTranspose> b_load_tr = {})
    {
        static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                      "The CDataType as defined in traits should be the same as correspoinding "
                      "C block tensor data type!");

        // hot loop:
        static_for<0, KRepeat, 1>{}([&](auto kIter) {
            LocalPrefetch<kIter.value>(a_block_window, b_block_window, a_load_tr, b_load_tr);
            __builtin_amdgcn_sched_barrier(0);
            // NOTE: Synchronize threads in a workgroup at the start of each MAC
            // cluster, but except the first, as we can shorten non-MAC cluster a bit
            // and there's no observable negative impact. The desired effect is waves in
            // a workgroup executing MAC in sync. This avoids some out-of-sync waves
            // hijacking MAC resource from other workgroups and reducing the chance of
            // latency hiding by waiting for the rest of the workgroup at the eventual
            // sync point.
            if constexpr(kIter.value != 0 || KRepeat == 1)
            {
                __builtin_amdgcn_s_barrier();
                __builtin_amdgcn_sched_barrier(0);
            }

            static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                    // read A warp tensor from A block tensor
                    AWarpTensor a_warp_tensor;

                    a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, kInnerIter>{}, a_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                        // read B warp tensor from B block tensor
                        BWarpTensor b_warp_tensor;

                        b_warp_tensor.get_thread_buffer() =
                            b_warp_tile_.get_y_sliced_thread_data(
                                merge_sequences(sequence<nIter, kInnerIter>{},
                                                b_warp_y_index_zeros),
                                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
                        // read C warp tensor from C block tensor
                        CWarpTensor c_warp_tensor;

                        c_warp_tensor.get_thread_buffer() =
                            c_block_tensor.get_y_sliced_thread_data(
                                merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                        // The block_sync_lds() here performs double duty:
                        // A) safeguard against data hazard because barrier from
                        // blockwise_gemm is moved here B) reduce VMEM FIFO congestion
                        // by applying small delays to different wavefronts It is
                        // performed near the end of MAC cluster to minimize lgkmcnt
                        // penalty
                        if constexpr(kIter.value == KRepeat - 1 &&
                                     kInnerIter.value == KInnerLoopIter - 1 &&
                                     mIter.value == MIterPerWarp - 1 &&
                                     nIter.value == NIterPerWarp - 1)
                        {
                            __builtin_amdgcn_sched_barrier(0);
                            block_sync_lds();
                            __builtin_amdgcn_sched_barrier(0);
                        }
                        // warp GEMM
                        WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                        // write C warp tensor into C block tensor
                        c_block_tensor.set_y_sliced_thread_data(
                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                            c_warp_tensor.get_thread_buffer());

                        // Raise instruction priority once the first MAC of the
                        // cluster has issued.
                        if constexpr(kInnerIter.value == 0 && mIter.value == 0 &&
                                     nIter.value == 0)
                        {
                            __builtin_amdgcn_sched_barrier(0);
                            __builtin_amdgcn_s_setprio(1);
                            __builtin_amdgcn_sched_barrier(0);
                        }
                    });
                });
            });

            // Restore normal instruction priority after the MAC cluster.
            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_setprio(0);
            __builtin_amdgcn_sched_barrier(0);
        });
    }
};
|
||||
|
||||
public:
|
||||
// Creates the per-thread C accumulator tile: the block-level outer encoding
// (MIterPerWarp x MWarp over M, NIterPerWarp x NWarp over N) with the
// warp-level C encoding embedded into it.
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
    constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 1>>,
        sequence<1, 2>,
        sequence<0, 0>>{};

    // Embed the per-warp C distribution into the block-level outer encoding.
    constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
    constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
    auto c_block_tensor          = make_static_distributed_tensor<CDataType>(c_block_dstr);

    return c_block_tensor;
}
|
||||
|
||||
// Pre-load A/B into the scheduler-specific register tiles ahead of the MAC
// loop. Forwards to the active BlockGemmImpl specialization.
// NOTE(review): the Interwave specialization's LocalPrefetch takes a KIdx
// template argument and the Default specialization has none, so this
// forwarder appears to match only the Intrawave implementation — confirm
// intended usage with callers.
template <typename ASmemBlockWindow,
          typename BSmemBlockWindow,
          bool ALoadTranspose = false,
          bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
                                  const BSmemBlockWindow& b_block_window,
                                  bool_constant<ALoadTranspose> a_load_tr = {},
                                  bool_constant<BLoadTranspose> b_load_tr = {})
{
    block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}
|
||||
|
||||
// C += A * B
|
||||
// C += A * B
// Accumulating entry point: forwards to the scheduler-specific
// implementation selected by Traits::Scheduler.
template <typename CBlockTensor,
          typename ASmemBlockWindow,
          typename BSmemBlockWindow,
          bool ALoadTranspose = false,
          bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                               const ASmemBlockWindow& a_block_window,
                               const BSmemBlockWindow& b_block_window,
                               bool_constant<ALoadTranspose> a_load_tr = {},
                               bool_constant<BLoadTranspose> b_load_tr = {})
{
    block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
}
|
||||
|
||||
// C = A * B
|
||||
// C = A * B
// Value-returning entry point: builds a fresh (uninitialized) C block tile,
// accumulates into it via the scheduler-specific implementation, and returns
// it by value.
template <typename ASmemBlockWindow,
          typename BSmemBlockWindow,
          bool ALoadTranspose = false,
          bool BLoadTranspose = false>
CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window,
                               const BSmemBlockWindow& b_block_window,
                               bool_constant<ALoadTranspose> a_load_tr = {},
                               bool_constant<BLoadTranspose> b_load_tr = {})
{
    auto c_block_tensor = MakeCBlockTile();
    block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
    return c_block_tensor;
}
|
||||
|
||||
private:
|
||||
BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,374 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/mx_pipeline_ag_bg_cr_v1_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = MXGemmPipelineAgBgCrPolicy<Problem>>
|
||||
struct MXGemmPipelineAgBgCrV1
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ComputeType = ADataType;
|
||||
static_assert(sizeof(ADataType) >= sizeof(BDataType));
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using AsDataType = ck_tile::tuple<ADataType>;
|
||||
using BsDataType = ck_tile::tuple<BDataType>;
|
||||
using AsLayout = ck_tile::tuple<ALayout>;
|
||||
using BsLayout = ck_tile::tuple<BLayout>;
|
||||
using AElementWise = element_wise::PassThrough;
|
||||
using BElementWise = element_wise::PassThrough;
|
||||
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
using BlockFlatmm =
|
||||
remove_cvref_t<decltype(PipelinePolicy::GetBlockFlatmm())>;
|
||||
|
||||
static constexpr auto config =
|
||||
BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
static constexpr index_t NumWaveGroups = BlockSize / WaveSize;
|
||||
static constexpr bool UsePersistentKernel = true;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
static constexpr index_t MXdlPack = Problem::MXdlPack;
|
||||
static constexpr index_t NXdlPack = Problem::NXdlPack;
|
||||
static constexpr index_t KXdlPack = Problem::KXdlPack;
|
||||
|
||||
static constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
|
||||
static constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
|
||||
static constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
|
||||
|
||||
// Whether the main block loop needs a hot-loop body: any positive loop count
// runs the hot loop in this pipeline.
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return num_loop > 0;
}
|
||||
|
||||
// This pipeline always treats the epilogue as a full tail iteration,
// regardless of the loop count.
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t /* num_loop */)
{
    return TailNumber::Full;
}
|
||||
|
||||
// Dispatch helper: invokes f with compile-time hot-loop and tail-number tags.
// The runtime arguments are ignored because this pipeline only supports
// TailNumber::Full (see GetBlockLoopTailNum).
template <bool HasHotLoop, typename Callable>
CK_TILE_HOST_DEVICE static auto TailHandler(Callable&& f, bool /* has_hot_loop */, TailNumber /* tail_num */)
{
    return f(bool_constant<HasHotLoop>{}, constant<TailNumber::Full>{});
}
|
||||
|
||||
// Shared-memory byte requirement of this pipeline, as computed by the policy.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    return PipelinePolicy::GetSmemSize();
}
|
||||
|
||||
// Vector access size for A; equals the number of A elements packed per
// storage unit (numeric_traits<ADataType>::PackedSize).
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA()
{
    return APackedSize;
}
|
||||
|
||||
// Vector access size for B; equals the number of B elements packed per
// storage unit (numeric_traits<BDataType>::PackedSize).
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
{
    return BPackedSize;
}
|
||||
|
||||
static constexpr bool Preshuffle = false;
|
||||
|
||||
// Runs the pipeline (forwarding all arguments to Run_) and repacks the
// returned per-warp C tensors into a single block-level C accumulator tile.
template <typename... Args>
CK_TILE_DEVICE auto operator()(Args&&... args) const
{
    auto c_warp_tensors = Run_(std::forward<Args>(args)...);

    // Block GEMM Acc register tile
    using CWarpDstr = typename WG::CWarpDstr;
    // Y-dimension lengths / zero start-indices for slicing the block tile.
    constexpr auto c_warp_y_lengths =
        to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
    constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
    auto c_block_tile                   = BlockFlatmm{}.MakeCBlockTile();
    // Copy each (mIter, nIter) warp tensor into its slice of the block tile.
    static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            c_block_tile.set_y_sliced_thread_data(
                merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                c_warp_tensors(mIter)(nIter).get_thread_buffer());
        });
    });
    return c_block_tile;
}
|
||||
|
||||
/// MX-scaled GEMM main loop with ping-pong LDS double buffering.
///
/// Per iteration: A and B tiles are async-copied global->LDS into the
/// ping/pong buffer, e8m0 block-scales (1 int32 packs KXdlPack scales along K)
/// are loaded straight to registers, and the warp GEMM consumes the *other*
/// buffer. Scales are themselves double-buffered two iterations ahead.
///
/// @param a_copy_dram_window_tmp   A global-memory window (tuple; slot 0 used)
/// @param b_flat_dram_block_window_tmp B (flat layout) global-memory window
/// @param scale_a_window  scale-A window, [M, K/32/KXdlPack] packed-int32 layout
/// @param scale_b_window  scale-B window, [K/32/KXdlPack, N] packed-int32 layout
/// @param num_loop        number of K-block iterations (assumed >= 1)
/// @param p_smem_ping/p_smem_pong  the two LDS buffers (A region, then B region)
/// @return per-warp accumulator tensors [MIterPerWarp][NIterPerWarp]
template <typename ADramBlockWindowTmp,
          typename BFlatBlockWindowTmp,
          typename ScaleADramBlockWindowTmp,
          typename ScaleBDramBlockWindowTmp>
CK_TILE_DEVICE auto Run_(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
                         const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
                         const ScaleADramBlockWindowTmp& scale_a_window,
                         const ScaleBDramBlockWindowTmp& scale_b_window,
                         index_t num_loop,
                         void* __restrict__ p_smem_ping,
                         void* __restrict__ p_smem_pong) const
{
    using CWarpTensor = typename WG::CWarpTensor;

    // A DRAM Window — rebuilt on the async-load (swizzled) descriptor so the
    // global->LDS copy lands in the LDS layout expected by the warp loads.
    auto a_dram_window =
        make_tile_window(PipelinePolicy::MakeMX_AAsyncLoadDramDescriptor(
                             a_copy_dram_window_tmp.at(number<0>{}).get_bottom_tensor_view()),
                         a_copy_dram_window_tmp.at(number<0>{}).get_window_lengths(),
                         a_copy_dram_window_tmp.at(number<0>{}).get_window_origin(),
                         PipelinePolicy::MakeMX_ADramTileDistribution());

    // B DRAM Window — same treatment as A.
    auto b_dram_window =
        make_tile_window(PipelinePolicy::MakeMX_BAsyncLoadDramDescriptor(
                             b_flat_dram_block_window_tmp.at(number<0>{}).get_bottom_tensor_view()),
                         b_flat_dram_block_window_tmp.at(number<0>{}).get_window_lengths(),
                         b_flat_dram_block_window_tmp.at(number<0>{}).get_window_origin(),
                         PipelinePolicy::MakeMX_BDramTileDistribution());

    // Scale A DRAM Window
    // With 1D K-only packing: window size is [MWarp * WG::kM, kKPerBlock / 32 / KXdlPack]
    constexpr index_t ScaleBlockSize = 32;
    auto scale_a_dram_window = make_tile_window(
        scale_a_window.get_bottom_tensor_view(),
        make_tuple(number<MWarp * WG::kM>{}, number<kKPerBlock / ScaleBlockSize / KXdlPack>{}),
        scale_a_window.get_window_origin(),
        PipelinePolicy::MakeMX_ScaleA_FlatDramTileDistribution());
    // Precomputed (scalar, wave-uniform) linear offsets for stepping one M
    // repeat / one packed-K column without re-deriving addresses per load.
    const auto scale_a_dram_step_m = amd_wave_read_first_lane(
        scale_a_dram_window.get_load_offset(tuple<number<MWarp * WG::kM>, number<0>>{}));
    const auto scale_a_dram_step_k = amd_wave_read_first_lane(
        scale_a_dram_window.get_load_offset(tuple<number<0>, number<kKPerBlock / ScaleBlockSize / KXdlPack>>{}));

    // Scale B DRAM Window
    // With 1D K-only packing and [K/32/4, N] layout: window size is [kKPerBlock / 32 / KXdlPack, NWarp * WG::kN]
    auto scale_b_dram_window = make_tile_window(
        scale_b_window.get_bottom_tensor_view(),
        make_tuple(number<kKPerBlock / ScaleBlockSize / KXdlPack>{}, number<NWarp * WG::kN>{}),
        scale_b_window.get_window_origin(),
        PipelinePolicy::MakeMX_ScaleB_DramTileDistribution());
    const auto scale_b_dram_step_k = amd_wave_read_first_lane(
        scale_b_dram_window.get_load_offset(tuple<number<kKPerBlock / ScaleBlockSize / KXdlPack>, number<0>>{}));
    const auto scale_b_dram_step_n = amd_wave_read_first_lane(
        scale_b_dram_window.get_load_offset(tuple<number<0>, number<NWarp * WG::kN>>{}));

    // LDS Views — each ping/pong buffer holds A first, then B at offset
    // GetSmemSizeA().
    ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
    ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);

    constexpr index_t a_lds_bytes = PipelinePolicy::GetSmemSizeA();
    BDataType* p_b_lds_ping = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem_ping) + a_lds_bytes);
    BDataType* p_b_lds_pong = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem_pong) + a_lds_bytes);

    constexpr auto a_lds_block_desc = PipelinePolicy::MakeMX_ALdsBlockDescriptor();
    constexpr auto b_lds_block_desc = PipelinePolicy::MakeMX_BLdsBlockDescriptor();

    auto a_lds_block_ping = make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
    auto a_lds_block_pong = make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
    auto b_lds_block_ping = make_tensor_view<address_space_enum::lds>(p_b_lds_ping, b_lds_block_desc);
    auto b_lds_block_pong = make_tensor_view<address_space_enum::lds>(p_b_lds_pong, b_lds_block_desc);

    // Store Windows (for Async Copy) — destinations of the global->LDS copies.
    auto a_store_lds_window_ping = make_tile_window(a_lds_block_ping, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
    auto a_store_lds_window_pong = make_tile_window(a_lds_block_pong, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
    auto b_store_lds_window_ping = make_tile_window(b_lds_block_ping, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
    auto b_store_lds_window_pong = make_tile_window(b_lds_block_pong, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

    // Load Windows (for Warp Load) — one warp-tile each; load_k below offsets
    // them per (m/n, k) repeat.
    auto a_warp_window_ping = make_tile_window(a_lds_block_ping, make_tuple(number<MWarp * WG::kM>{}, number<WG::kK>{}), {0, 0}, PipelinePolicy::MakeMX_ALDS_TileDistribution());
    auto a_warp_window_pong = make_tile_window(a_lds_block_pong, make_tuple(number<MWarp * WG::kM>{}, number<WG::kK>{}), {0, 0}, PipelinePolicy::MakeMX_ALDS_TileDistribution());
    auto b_warp_window_ping = make_tile_window(b_lds_block_ping, make_tuple(number<NWarp * WG::kN>{}, number<WG::kK>{}), {0, 0}, PipelinePolicy::MakeMX_BLDS_TileDistribution());
    auto b_warp_window_pong = make_tile_window(b_lds_block_pong, make_tuple(number<NWarp * WG::kN>{}, number<WG::kK>{}), {0, 0}, PipelinePolicy::MakeMX_BLDS_TileDistribution());

    // Register Tiles — accumulators, one per (m, n) warp repeat.
    statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp> c_warp_tensors;

    // Initialize C
    static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            clear_tile(c_warp_tensors(mIter)(nIter));
        });
    });

    // Scale Tiles
    // With 1D K-only packing: one scale tile per M/N iter, indexed by K packed iter
    // K dimension: each K iter processes WG::kK elements, each int32 has KXdlPack scales covering KXdlPack*32 elements
    // So each KIterPerWarp needs KIterPerWarp/(KXdlPack) packed scale elements
    constexpr index_t ScaleKPackedPerIter = (KIterPerWarp * WG::kK) / (32 * KXdlPack);
    using ScaleATileType = statically_indexed_array<statically_indexed_array<decltype(load_tile_with_offset(scale_a_dram_window, tuple<number<0>, number<0>>{})), ScaleKPackedPerIter>, MIterPerWarp>;
    using ScaleBTileType = statically_indexed_array<statically_indexed_array<decltype(load_tile_with_offset(scale_b_dram_window, tuple<number<0>, number<0>>{})), ScaleKPackedPerIter>, NIterPerWarp>;

    ScaleATileType scale_a_tile_ping, scale_a_tile_pong;
    ScaleBTileType scale_b_tile_ping, scale_b_tile_pong;

    // Fire one global->LDS async copy of a full block tile.
    auto async_load_tile_ = [](auto lds, auto dram) {
        async_load_tile(lds, dram, number<-1>{}, true_type{}, true_type{});
    };

    // Load one iteration's worth of packed scales into registers and advance
    // both scale windows by one K block.
    auto load_scales_ = [&](auto& scale_a, auto& scale_b) {
        // Load scales for each M/N iteration
        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
                scale_a(mIter)(kPacked) = load_tile_with_offset(
                    scale_a_dram_window, mIter * scale_a_dram_step_m + kPacked * scale_a_dram_step_k);
            });
        });
        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
                // Scale B is [K/32/4, N], so K is first dimension
                scale_b(nIter)(kPacked) = load_tile_with_offset(
                    scale_b_dram_window, kPacked * scale_b_dram_step_k + nIter * scale_b_dram_step_n);
            });
        });
        move_tile_window(scale_a_dram_window, {0, kKPerBlock / ScaleBlockSize / KXdlPack});
        move_tile_window(scale_b_dram_window, {kKPerBlock / ScaleBlockSize / KXdlPack, 0});
    };

    // Helper for Main Loop: consume one LDS buffer (all K iterations) with
    // software-pipelined (double-buffered) register loads.
    auto warp_gemm_loop = [&](auto& a_warp_window, auto& b_warp_window, auto& scale_a, auto& scale_b) {
        // Define register tiles types for double buffering
        using AValType = decltype(load_tile_with_offset(a_warp_window, tuple<number<0>, number<0>>{}));
        using BValType = decltype(load_tile_with_offset(b_warp_window, tuple<number<0>, number<0>>{}));

        statically_indexed_array<statically_indexed_array<AValType, MIterPerWarp>, 2> a_vals;
        statically_indexed_array<statically_indexed_array<BValType, NIterPerWarp>, 2> b_vals;

        // Load all M/N repeats of K-slice K into register buffer buf_idx.
        auto load_k = [&]<typename K, typename Buf>(const K&, const Buf& buf_idx) {
            static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) {
                a_vals(buf_idx)(m_iter) = load_tile_with_offset(
                    a_warp_window,
                    tuple<number<m_iter * MWarp * WG::kM>, number<K{} * WG::kK>>{});
            });
            static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
                b_vals(buf_idx)(n_iter) = load_tile_with_offset(
                    b_warp_window,
                    tuple<number<n_iter * NWarp * WG::kN>, number<K{} * WG::kK>>{});
            });
        };

        // Prologue: Load K=0
        load_k(number<0>{}, number<0>{});

        static_for<0, KIterPerWarp, 1>{}([&](auto k_iter) {
            constexpr auto cur_buf = k_iter % 2;
            constexpr auto nxt_buf = (k_iter + 1) % 2;

            // Prefetch K+1
            if constexpr(k_iter < KIterPerWarp - 1) {
                load_k(number<k_iter + 1>{}, number<nxt_buf>{});
            }

            // Map k_iter to packed scale index
            // Each k_iter processes WG::kK elements
            // Each packed int32 contains KXdlPack scales, each covering 32 elements
            // So we need k_iter * WG::kK / (32 * KXdlPack) to get the packed index
            // and k_iter * WG::kK / 32 % KXdlPack to get which scale within the pack
            constexpr index_t kScalePacked = (k_iter * WG::kK) / (32 * KXdlPack);
            constexpr index_t kScaleInPack = ((k_iter * WG::kK) / 32) % KXdlPack;

            static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) {
                // OpSel selects which of the KXdlPack packed e8m0 values to use
                constexpr auto OpSelA = kScaleInPack;

                static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
                    // OpSel selects which of the KXdlPack packed e8m0 values to use
                    constexpr auto OpSelB = kScaleInPack;

                    WG{}.template operator()<OpSelA, OpSelB>(
                        c_warp_tensors(m_iter)(n_iter),
                        bit_cast<typename WG::AWarpTensor>(a_vals(number<cur_buf>{})(m_iter)),
                        bit_cast<typename WG::BWarpTensor>(b_vals(number<cur_buf>{})(n_iter)),
                        scale_a(m_iter)(number<kScalePacked>{}).get_thread_buffer()[0],
                        scale_b(n_iter)(number<kScalePacked>{}).get_thread_buffer()[0]);
                });
            });
        });
    };

    // Prologue: Load first block
    async_load_tile_(a_store_lds_window_ping, a_dram_window);
    async_load_tile_(b_store_lds_window_ping, b_dram_window);

    // Load Scales (Ping - Iter 0)
    load_scales_(scale_a_tile_ping, scale_b_tile_ping);

    // Load Scales (Pong - Iter 1)
    if (num_loop > 1) {
        load_scales_(scale_a_tile_pong, scale_b_tile_pong);
    }

    // Move DRAM windows
    move_tile_window(a_dram_window, {0, kKPerBlock});
    move_tile_window(b_dram_window, {0, kKPerBlock});
    // Scale windows already moved in load_scales_

    // Main Loop
    index_t i = 0;
    do {
        // Wait for LDS load
        // NOTE(review): s_waitcnt<0> presumably drains all outstanding async
        // copies before the barrier — confirm the counter this wraps.
        s_waitcnt<0>();
        block_sync_lds();

        // Trigger next load (Ping-Pong): while computing on one buffer,
        // prefetch block i+1 into the other.
        if (i < num_loop - 1) {
            if (i % 2 == 0) {
                async_load_tile_(a_store_lds_window_pong, a_dram_window);
                async_load_tile_(b_store_lds_window_pong, b_dram_window);
            } else {
                async_load_tile_(a_store_lds_window_ping, a_dram_window);
                async_load_tile_(b_store_lds_window_ping, b_dram_window);
            }
            move_tile_window(a_dram_window, {0, kKPerBlock});
            move_tile_window(b_dram_window, {0, kKPerBlock});
        }

        // Compute on the buffer that finished loading last iteration; scales
        // for iteration i+2 are fetched behind the MFMAs.
        if (i % 2 == 0) {
            warp_gemm_loop(a_warp_window_ping, b_warp_window_ping, scale_a_tile_ping, scale_b_tile_ping);
            // Load next scales (Ping - Iter i+2)
            if (i + 2 < num_loop) {
                load_scales_(scale_a_tile_ping, scale_b_tile_ping);
            }
        } else {
            warp_gemm_loop(a_warp_window_pong, b_warp_window_pong, scale_a_tile_pong, scale_b_tile_pong);
            // Load next scales (Pong - Iter i+2)
            if (i + 2 < num_loop) {
                load_scales_(scale_a_tile_pong, scale_b_tile_pong);
            }
        }

        i++;
    } while (i < num_loop);

    return c_warp_tensors;
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,548 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
// Convenience compile-time indices.
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};

// Global-memory loads are issued in 128-byte packets of DWORDx4 (16 B) accesses.
static constexpr index_t kDramLoadPackBytes = 128;
static constexpr index_t DWORDx4 = 16;

static constexpr int MXdlPack = 1; // No M packing
static constexpr int NXdlPack = 1; // No N packing
static constexpr int KXdlPack = 4; // Pack 4 consecutive e8m0 scales in K = 4 bytes = 1 int32

using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
// Elements per storage unit (e.g. 2 for 4-bit types).
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;

using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);

// Block/warp geometry derived from the problem's tile shape.
using TileShape = typename Problem::BlockGemmShape;
using BlockWarps = typename TileShape::BlockWarps;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t WaveNum = BlockSize / WaveSize;

static constexpr index_t MPerBlock = TileShape::kM;
static constexpr index_t NPerBlock = TileShape::kN;
static constexpr index_t KPerBlock = TileShape::kK;
static constexpr index_t MWarps = BlockWarps::at(I0);
static constexpr index_t NWarps = BlockWarps::at(I1);
static_assert(WaveNum == MWarps * NWarps, "Block warps do not match block size");

// Per-XDL (warp-gemm) tile; this policy only supports 16x16 MFMA tiles.
static constexpr index_t MPerXdl = TileShape::WarpTile::at(I0);
static constexpr index_t NPerXdl = TileShape::WarpTile::at(I1);
static constexpr index_t KPerXdl = TileShape::WarpTile::at(I2);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static constexpr index_t K_Lane = get_warp_size() / 16; // 4
static constexpr index_t K_Thread = KPerXdl / K_Lane; // 32

public:
// Innermost contiguous K-vector per 16-byte access, in elements.
static constexpr index_t AK1 = DWORDx4 * APackedSize;
static constexpr index_t BK1 = DWORDx4 * BPackedSize;
||||
|
||||
/// Builds the block-level flatmm functor used by the pipeline: dispatches a
/// warp-gemm for the problem's warp tile and wires it into the
/// ASmem/BSmem/CReg block flatmm via a custom policy.
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
{
    using WarpTile = typename Problem::BlockGemmShape::WarpTile;
    // Select the MFMA-backed warp gemm matching the (M, N, K) warp tile.
    using WarpGemm = WarpGemmDispatcher< //
        ADataType,
        BDataType,
        typename Problem::CDataType,
        WarpTile::at(I0),
        WarpTile::at(I1),
        WarpTile::at(I2),
        Problem::TransposeC>;
    using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< //
        ADataType,
        BDataType,
        typename Problem::CDataType,
        BlockWarps,
        WarpGemm>;
    return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
}
|
||||
|
||||
/// Rewraps a naive 2D [rows, cols] A tensor view with an xor-swizzled
/// descriptor so that the async global->LDS copy writes bank-conflict-free
/// data matching MakeMX_ALdsBlockDescriptor. The swizzle xors the low M bits
/// (M1 = 4) with the K packet index (K1 = 8 packets of 16 B).
/// @param naive_view 2D row-major A view (rows = M, cols = K)
/// @return tensor_view over the same buffer with the swizzled 2D descriptor
template <typename TensorView>
CK_TILE_DEVICE static constexpr auto
MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view)
{
    const auto& naive_desc = naive_view.get_tensor_descriptor();
    constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
    static_assert(ndims == 2, "only support 2D tensor");
    const auto rows = naive_desc.get_length(number<0>{});
    const auto cols = naive_desc.get_length(number<1>{});

    // K split: K2 elements per 16-byte access, K1 accesses per 128-byte packet.
    constexpr index_t K2 = AK1; // f4=32; f8=16
    constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
    const index_t K0 = cols / (K1 * K2);
    const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});

    constexpr index_t M1 = 4; // so that we can use imm offset to load lds
    const index_t M0 = rows / M1;
    const auto row_lens = make_tuple(M0, number<M1>{});

    // 5D packed descriptor, then xor M1<->K1, then merge back to 2D.
    const auto desc_0 =
        make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
    const auto desc_1 = transform_tensor_descriptor(
        desc_0,
        make_tuple(make_pass_through_transform(M0),
                   make_xor_transform(make_tuple(number<M1>{}, number<K1>{})),
                   make_pass_through_transform(K0),
                   make_pass_through_transform(number<K2>{})),
        make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
        make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
    const auto desc = transform_tensor_descriptor( //
        desc_1,
        make_tuple(make_merge_transform_v3_division_mod(row_lens),
                   make_merge_transform_v3_division_mod(col_lens)),
        make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));
    // printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1));

    // Same buffer, new descriptor.
    return tensor_view<typename TensorView::buffer_view,
                       remove_cvref_t<decltype(desc)>,
                       TensorView::DstInMemOp>{naive_view.buf_, desc};
}
|
||||
|
||||
/// Thread distribution for the async A load from global memory.
/// Lanes vectorize AK1 contiguous K elements per access, the block's waves
/// stack along M, and leftover M/K extents become per-thread iterations.
CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution()
{
    // K decomposition: vector (innermost) x lanes x per-thread iterations.
    constexpr index_t kVecK  = AK1;                                      // f4=32; f8=16
    constexpr index_t kLaneK = kDramLoadPackBytes * APackedSize / kVecK; // 8
    constexpr index_t kIterK = KPerBlock / (kLaneK * kVecK);             // KPerBlock/256

    // M decomposition: lanes x waves x per-thread iterations.
    constexpr index_t kLaneM = WaveSize / kLaneK;                        // 8
    constexpr index_t kWaveM = BlockSize / WaveSize;                     // 4
    constexpr index_t kIterM = MPerBlock / (kLaneM * kWaveM);

    static_assert(kIterM * kWaveM * kLaneM == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
    static_assert(kIterK * kLaneK * kVecK == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");

    using Encoding = tile_distribution_encoding< //
        sequence<1>,
        tuple<sequence<kIterM, kWaveM, kLaneM>, sequence<kIterK, kLaneK, kVecK>>,
        tuple<sequence<1>, sequence<1, 2>>, // waves <- M-waves; lanes <- (M-lanes, K-lanes)
        tuple<sequence<1>, sequence<2, 1>>,
        sequence<1, 2, 2>,  // per-thread ys: M-iters, K-iters, K-vector
        sequence<0, 0, 2>>;
    return make_static_tile_distribution(Encoding{});
}
|
||||
|
||||
/// LDS layout for the A block tile: a padded, xor-swizzled 7D layout
/// (M0,K0,M1,M2,M3,K1,K2) with Pad bytes inserted per M1 slab to avoid bank
/// conflicts, xor of M3<->K1 matching the DRAM-side swizzle, finally merged
/// back to a logical 2D [MPerBlock, KPerBlock] descriptor.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor()
{
    constexpr index_t K2 = AK1; // f4=32; f8=16
    constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
    constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
    static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");

    constexpr index_t M3 = 4; // so that we can use imm offset to load lds
    constexpr index_t M2 = WaveSize / K1 / M3; // 2
    constexpr index_t M1 = MPerXdl / (M2 * M3); // 2
    constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
    static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!");

    constexpr index_t Pad = 4 * K2; // 4 * 32

    // Hand-computed strides: each M1 step adds Pad elements of padding.
    constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
        make_tuple(number<M0>{},
                   number<K0>{},
                   number<M1>{},
                   number<M2>{},
                   number<M3>{},
                   number<K1>{},
                   number<K2>{}),
        make_tuple(number<K0*(M1 * (M2 * M3 * K1 * K2) + (M1 - 1) * Pad)>{},
                   number<M1*(M2 * M3 * K1 * K2) + (M1 - 1) * Pad>{},
                   number<M2 * M3 * K1 * K2 + Pad>{},
                   number<M3 * K1 * K2>{},
                   number<K1 * K2>{},
                   number<K2>{},
                   number<1>{}),
        number<K2>{},
        number<1>{});

    // Apply the same M3<->K1 xor swizzle used on the DRAM descriptor.
    constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor(
        a_lds_block_desc_0,
        make_tuple(make_pass_through_transform(M0),
                   make_pass_through_transform(K0),
                   make_pass_through_transform(M1),
                   make_pass_through_transform(M2),
                   make_xor_transform(make_tuple(number<M3>{}, number<K1>{})),
                   make_pass_through_transform(number<K2>{})),
        make_tuple(sequence<0>{},
                   sequence<1>{},
                   sequence<2>{},
                   sequence<3>{},
                   sequence<4, 5>{},
                   sequence<6>{}),
        make_tuple(sequence<0>{},
                   sequence<1>{},
                   sequence<2>{},
                   sequence<3>{},
                   sequence<4, 5>{},
                   sequence<6>{}));
    // Merge back to logical 2D [M, K].
    constexpr auto a_lds_block_desc = transform_tensor_descriptor(
        a_lds_block_desc_1,
        make_tuple(make_merge_transform_v3_division_mod(
                       make_tuple(number<M0>{}, number<M1>{}, number<M2>{}, number<M3>{})),
                   make_merge_transform_v3_division_mod(
                       make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
        make_tuple(sequence<0, 2, 3, 4>{}, sequence<1, 5, 6>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    // return a_lds_block_desc_permuted;
    return a_lds_block_desc;
}
|
||||
|
||||
/// Thread distribution for the register-side A load from LDS: 16 lanes along
/// M, K_Lane lanes along K, AK1 elements per access. When one access covers a
/// lane's whole K slice (K_Thread == AK1) a single-y encoding suffices;
/// otherwise an extra per-thread K iteration dimension is added.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution()
{
    static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");

    if constexpr(K_Thread == AK1)
        return make_static_tile_distribution(
            tile_distribution_encoding< //
                sequence<NWarps>,
                tuple<sequence<MWarps, MXdlPack, MPerXdl>, sequence<K_Lane, AK1>>,
                tuple<sequence<1, 0>, sequence<2, 1>>,
                tuple<sequence<0, 0>, sequence<0, 2>>,
                sequence<2>,
                sequence<1>>{});
    else
        return make_static_tile_distribution(tile_distribution_encoding< //
                                                 sequence<NWarps>,
                                                 tuple<sequence<MWarps, MXdlPack, MPerXdl>,
                                                       sequence<K_Thread / AK1, K_Lane, AK1>>,
                                                 tuple<sequence<1, 0>, sequence<2, 1>>,
                                                 tuple<sequence<0, 0>, sequence<1, 2>>,
                                                 sequence<2, 2>, // K_Thread/AK1, AK1
                                                 sequence<0, 2>>{});
}
|
||||
|
||||
|
||||
/// B-side counterpart of MakeMX_AAsyncLoadDramDescriptor: xor-swizzles a
/// naive 2D [rows = N, cols = K] view (N1 = 4 xor K1 = 8 packets) so the
/// async copy lands in the layout of MakeMX_BLdsBlockDescriptor.
/// @param naive_view 2D B view (rows = N, cols = K)
/// @return tensor_view over the same buffer with the swizzled 2D descriptor
template <typename TensorView>
CK_TILE_DEVICE static constexpr auto
MakeMX_BAsyncLoadDramDescriptor(const TensorView& naive_view)
{
    const auto& naive_desc = naive_view.get_tensor_descriptor();
    constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
    static_assert(ndims == 2, "only support 2D tensor");
    const auto rows = naive_desc.get_length(number<0>{});
    const auto cols = naive_desc.get_length(number<1>{});

    constexpr index_t K2 = BK1; // f4=32; f8=16
    constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
    const index_t K0 = cols / (K1 * K2);
    const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});

    constexpr index_t N1 = 4; // so that we can use imm offset to load lds
    const index_t N0 = rows / N1;
    const auto row_lens = make_tuple(N0, number<N1>{});

    // 5D packed descriptor, xor N1<->K1, merge back to 2D.
    const auto desc_0 =
        make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
    const auto desc_1 = transform_tensor_descriptor(
        desc_0,
        make_tuple(make_pass_through_transform(N0),
                   make_xor_transform(make_tuple(number<N1>{}, number<K1>{})),
                   make_pass_through_transform(K0),
                   make_pass_through_transform(number<K2>{})),
        make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
        make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
    const auto desc = transform_tensor_descriptor( //
        desc_1,
        make_tuple(make_merge_transform_v3_division_mod(row_lens),
                   make_merge_transform_v3_division_mod(col_lens)),
        make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    return tensor_view<typename TensorView::buffer_view,
                       remove_cvref_t<decltype(desc)>,
                       TensorView::DstInMemOp>{naive_view.buf_, desc};
}
|
||||
|
||||
/// Thread distribution for the async B load from global memory — exact
/// N-side mirror of MakeMX_ADramTileDistribution.
CK_TILE_DEVICE static constexpr auto MakeMX_BDramTileDistribution()
{
    // TODO: these could be replaced by the standard UniversalGEMM tile distributions??
    // K decomposition: vector (innermost) x lanes x per-thread iterations.
    constexpr index_t kVecK  = BK1;                                      // f4=32; f8=16
    constexpr index_t kLaneK = kDramLoadPackBytes * BPackedSize / kVecK; // 8
    constexpr index_t kIterK = KPerBlock / (kLaneK * kVecK);             // KPerBlock/256

    // N decomposition: lanes x waves x per-thread iterations.
    constexpr index_t kLaneN = WaveSize / kLaneK;                        // 8
    constexpr index_t kWaveN = BlockSize / WaveSize;                     // 4
    constexpr index_t kIterN = NPerBlock / (kLaneN * kWaveN);

    static_assert(kIterN * kWaveN * kLaneN == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
    static_assert(kIterK * kLaneK * kVecK == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");

    using Encoding = tile_distribution_encoding< //
        sequence<1>,
        tuple<sequence<kIterN, kWaveN, kLaneN>, sequence<kIterK, kLaneK, kVecK>>,
        tuple<sequence<1>, sequence<1, 2>>, // waves <- N-waves; lanes <- (N-lanes, K-lanes)
        tuple<sequence<1>, sequence<2, 1>>,
        sequence<1, 2, 2>,  // per-thread ys: N-iters, K-iters, K-vector
        sequence<0, 0, 2>>;
    return make_static_tile_distribution(Encoding{});
}
|
||||
|
||||
/// LDS layout for the B block tile — N-side mirror of
/// MakeMX_ALdsBlockDescriptor: padded, xor-swizzled (N3<->K1) 7D layout
/// merged back to a logical 2D [NPerBlock, KPerBlock] descriptor.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLdsBlockDescriptor()
{
    constexpr index_t K2 = BK1; // f4=32; f8=16
    constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
    constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
    static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");

    constexpr index_t N3 = 4; // so that we can use imm offset to load lds
    constexpr index_t N2 = WaveSize / K1 / N3; // 2
    constexpr index_t N1 = NPerXdl / (N2 * N3); // 2
    constexpr index_t N0 = NPerBlock / (N1 * N2 * N3); // NPerBlock/16
    static_assert(N0 * N1 * N2 * N3 == NPerBlock, "N0, N1, N2, N3 must cover whole NPerBlock!");

    constexpr index_t Pad = 4 * K2; // 4 * 32

    // Hand-computed strides: each N1 step adds Pad elements of padding.
    constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( //
        make_tuple(number<N0>{},
                   number<K0>{},
                   number<N1>{},
                   number<N2>{},
                   number<N3>{},
                   number<K1>{},
                   number<K2>{}),
        make_tuple(number<K0*(N1 * (N2 * N3 * K1 * K2) + (N1 - 1) * Pad)>{},
                   number<N1*(N2 * N3 * K1 * K2) + (N1 - 1) * Pad>{},
                   number<N2 * N3 * K1 * K2 + Pad>{},
                   number<N3 * K1 * K2>{},
                   number<K1 * K2>{},
                   number<K2>{},
                   number<1>{}),
        number<K2>{},
        number<1>{});

    // Same N3<->K1 xor swizzle as on the DRAM descriptor.
    constexpr auto b_lds_block_desc_1 = transform_tensor_descriptor(
        b_lds_block_desc_0,
        make_tuple(make_pass_through_transform(N0),
                   make_pass_through_transform(K0),
                   make_pass_through_transform(N1),
                   make_pass_through_transform(N2),
                   make_xor_transform(make_tuple(number<N3>{}, number<K1>{})),
                   make_pass_through_transform(number<K2>{})),
        make_tuple(sequence<0>{},
                   sequence<1>{},
                   sequence<2>{},
                   sequence<3>{},
                   sequence<4, 5>{},
                   sequence<6>{}),
        make_tuple(sequence<0>{},
                   sequence<1>{},
                   sequence<2>{},
                   sequence<3>{},
                   sequence<4, 5>{},
                   sequence<6>{}));
    // Merge back to logical 2D [N, K].
    constexpr auto b_lds_block_desc = transform_tensor_descriptor(
        b_lds_block_desc_1,
        make_tuple(make_merge_transform_v3_division_mod(
                       make_tuple(number<N0>{}, number<N1>{}, number<N2>{}, number<N3>{})),
                   make_merge_transform_v3_division_mod(
                       make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
        make_tuple(sequence<0, 2, 3, 4>{}, sequence<1, 5, 6>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    return b_lds_block_desc;
}
|
||||
|
||||
/// Thread distribution for the register-side B load from LDS — N-side mirror
/// of MakeMX_ALDS_TileDistribution (single-y when one BK1 access covers the
/// lane's K slice, otherwise an extra per-thread K iteration).
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLDS_TileDistribution()
{
    static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");

    if constexpr(K_Thread == BK1)
        return make_static_tile_distribution(
            tile_distribution_encoding< //
                sequence<MWarps>,
                tuple<sequence<NWarps, NXdlPack, NPerXdl>, sequence<K_Lane, BK1>>,
                tuple<sequence<1, 0>, sequence<2, 1>>,
                tuple<sequence<0, 0>, sequence<0, 2>>,
                sequence<2>,
                sequence<1>>{});
    else
        return make_static_tile_distribution(tile_distribution_encoding< //
                                                 sequence<MWarps>,
                                                 tuple<sequence<NWarps, NXdlPack, NPerXdl>,
                                                       sequence<K_Thread / BK1, K_Lane, BK1>>,
                                                 tuple<sequence<1, 0>, sequence<2, 1>>,
                                                 tuple<sequence<0, 0>, sequence<1, 2>>,
                                                 sequence<2, 2>,
                                                 sequence<0, 2>>{});
}
|
||||
|
||||
// TODO: create also MakeMX_BAsyncLoadDramDescriptor, MakeMX_BDramTileDistribution MakeMX_BLdsBlockDescriptor for non-flat B
|
||||
// to replace the below ones for flat B
|
||||
|
||||
// TODO: create also MakeMX_BAsyncLoadDramDescriptor, MakeMX_BDramTileDistribution MakeMX_BLdsBlockDescriptor for non-flat B
// to replace the below ones for flat B

/// Thread distribution for loading pre-shuffled ("flat") B directly from
/// global memory as raw bytes: all WaveSize lanes line up along flat-K, waves
/// stack along flat-N, and a second K iteration dimension is added when one
/// BK1 access does not cover a lane's K_Thread slice.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution()
{
    constexpr index_t K1 = WaveSize; // threads cnt in K dim
    constexpr index_t KWavePerBlk = 1;
    constexpr index_t K0 = KWavePerBlk;

    // Extra waves beyond flatNPerWarp re-read the same data (replication).
    constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;

    if constexpr(BK1 == K_Thread)
        return make_static_tile_distribution(
            tile_distribution_encoding< //
                sequence<WaveRepeat>,
                tuple<sequence<NWarps, NXdlPack>, // 4 2
                      sequence<K0, K1, BK1 / BPackedSize>>, // 1 64 32
                tuple<sequence<0, 1, 2>, sequence<2>>,
                tuple<sequence<0, 0, 0>, sequence<1>>,
                sequence<2>,
                sequence<2>>{});
    else
        return make_static_tile_distribution(
            tile_distribution_encoding< //
                sequence<WaveRepeat>,
                tuple<sequence<NWarps, NXdlPack>, // 4 2
                      sequence<K_Thread / BK1, K0, K1, BK1 / BPackedSize>>, // 2 1 64 16
                tuple<sequence<0, 1, 2>, sequence<2>>,
                tuple<sequence<0, 0, 1>, sequence<2>>,
                sequence<2, 2>,
                sequence<0, 3>>{});
}
|
||||
|
||||
/// Reinterprets a flat-B element window as a byte-addressed global window so
/// packed sub-byte types can be loaded with plain byte accesses.
/// Builds a uint8 tensor view over the same storage (scaling K extents and
/// the window origin by BPackedSize) and attaches the flat-bytes distribution.
/// @param window_tmp 2D flat-B window in element units
/// @return byte-unit tile window over the same global memory
template <typename WindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp)
{
    constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1);
    constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp;
    constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp;

    static_assert(std::decay_t<decltype(window_tmp)>::get_num_of_dimension() == 2);
    auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view();
    const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
    constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile;
    // Split flat-K into (block, in-block-bytes) then merge back so the
    // in-block extent is a compile-time constant.
    auto&& byte_tensor_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(
            flat_n, flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{})),
        make_tuple(make_pass_through_transform(flat_n),
                   make_merge_transform_v3_division_mod(make_tuple(
                       flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{}))),
        make_tuple(sequence<0>{}, sequence<1, 2>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));
    // Reinterpret the underlying storage as bytes.
    auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
    auto&& byte_tensor_view =
        make_tensor_view<address_space_enum::global>(byte_ptr, byte_tensor_desc);
    auto&& origin_tmp = window_tmp.get_window_origin();
    return make_tile_window(
        byte_tensor_view,
        make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp / BPackedSize>{}),
        {origin_tmp[0], origin_tmp[1] / BPackedSize},
        MakeMX_BFlatBytesDramTileDistribution());
}
|
||||
|
||||
/// Thread distribution for loading A block-scales from global memory.
/// Layout is [M, K/32/KXdlPack] where each element is an int32 packing
/// KXdlPack e8m0 scales; MPerXdl lanes map to M, the remaining lanes to
/// packed-K, and the result is replicated across the NWarps dimension.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
{
    // With 1D K-only packing: MXdlPack=1, so no complex M packing
    // Simple 2D distribution for [M, K/32/KXdlPack] layout
    constexpr index_t M_Lanes = TileShape::WarpTile::at(I0);
    constexpr index_t K_Lanes = 64 / M_Lanes;

    // Y dimension (M) decomposition - no packing factor
    constexpr index_t Y2 = M_Lanes;
    constexpr index_t Y1 = MWarps;
    constexpr index_t Y0 = MPerBlock / (Y1 * Y2);

    // X dimension (K) decomposition - each int32 contains KXdlPack scales
    constexpr index_t X0 = K_Lanes;
    constexpr index_t X1 = 1; // vec load of int32

    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<NWarps>, // repeat NWarps
                                   tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                   tuple<sequence<1, 0>, sequence<2, 1>>,
                                   tuple<sequence<1, 0>, sequence<0, 2>>,
                                   sequence<1, 2>,
                                   sequence<0, 1>>{});
}
|
||||
|
||||
/// @brief Static tile distribution for loading B-scale factors from DRAM.
///
/// Mirror of MakeMX_ScaleA_DramTileDistribution for the [K, N] scale layout:
/// the packed-scale K extent maps to (K lanes, unit vector load) and the block's
/// N extent to (repeats, NWarps, warp lanes), replicated over MWarps so every
/// M-warp reads the same B scales.
/// NOTE(review): K_Lanes = 64 / N_Lanes assumes a 64-lane wavefront — confirm
/// for the target architecture.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution()
{
    // With 1D K-only packing and [K/32/4, N] layout to match reference
    // Layout is [K, N] where K is packed int32
    constexpr index_t N_Lanes = TileShape::WarpTile::at(I1);
    constexpr index_t K_Lanes = 64 / N_Lanes;

    // First tuple element: K dimension decomposition
    constexpr index_t K0 = K_Lanes;
    constexpr index_t K1 = 1; // vec load of int32

    // Second tuple element: N dimension decomposition
    constexpr index_t N2 = N_Lanes;
    constexpr index_t N1 = NWarps;
    constexpr index_t N0 = NPerBlock / (N1 * N2);

    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<MWarps>, // repeat MWarps
                                   tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
                                   tuple<sequence<2, 1>, sequence<0, 1>>,
                                   tuple<sequence<0, 1>, sequence<0, 1>>,
                                   sequence<2, 1>,
                                   sequence<1, 0>>{});
}
|
||||
|
||||
/// @brief Static tile distribution for the flat A-scale DRAM layout.
///
/// Covers an [MWarps * MPerXdl, K/32/KXdlPack]-shaped scale tile: M decomposes
/// into (MWarps, MPerXdl), K into (K_Lane, 1) for a single-int32 vector load,
/// with replication over NWarps.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution()
{
    // With 1D K-only packing: simpler distribution for [MWarp*MPerXdl, K/32/KXdlPack]
    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<NWarps>, // repeat over NWarps
                                   tuple<sequence<MWarps, MPerXdl>, // M dimension
                                         sequence<K_Lane, 1>>, // K dimension (int32 vec load)
                                   tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
                                   tuple<sequence<0, 0>, sequence<0, 1>>, // which index
                                   // <repeat, vec_load>
                                   sequence<2>,
                                   sequence<1>>{});
}
|
||||
|
||||
/// @brief Static tile distribution for the flat B-scale DRAM layout.
///
/// Mirror of MakeMX_ScaleA_FlatDramTileDistribution for the transposed
/// [K/32/KXdlPack, NWarps * NPerXdl] layout: K decomposes into (K_Lane, 1) for a
/// single-int32 vector load, N into (NWarps, NPerXdl), with replication over MWarps.
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution()
{
    // With 1D K-only packing and [K/32/4, N] layout: [K/32/KXdlPack, NWarp*NPerXdl]
    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<MWarps>, // repeat over MWarps
                                   tuple<sequence<K_Lane, 1>, // K dimension (int32 vec load)
                                         sequence<NWarps, NPerXdl>>, // N dimension
                                   tuple<sequence<2, 1>, sequence<0, 1>>, // which direction
                                   tuple<sequence<0, 1>, sequence<0, 0>>, // which index
                                   // <repeat, vec_load>
                                   sequence<1>,
                                   sequence<2>>{});
}
|
||||
|
||||
/// @brief Bytes of LDS required for the A-tile staging buffer.
///
/// Derived from the A LDS block descriptor's element-space size, scaled by
/// sizeof(ADataType) and divided by APackedSize (multiple packed elements per
/// storage unit — presumably sub-byte types; verify against ADataType).
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
    const auto a_element_count = MakeMX_ALdsBlockDescriptor().get_element_space_size();
    const auto a_raw_bytes     = sizeof(ADataType) * a_element_count;
    return a_raw_bytes / APackedSize;
}
|
||||
|
||||
/// @brief Bytes of LDS required for the B-tile staging buffer.
///
/// Derived from the B LDS block descriptor's element-space size, scaled by
/// sizeof(BDataType) and divided by BPackedSize (multiple packed elements per
/// storage unit — presumably sub-byte types; verify against BDataType).
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
    const auto b_element_count = MakeMX_BLdsBlockDescriptor().get_element_space_size();
    const auto b_raw_bytes     = sizeof(BDataType) * b_element_count;
    return b_raw_bytes / BPackedSize;
}
|
||||
|
||||
/// @brief Combined LDS footprint in bytes of the A and B tile buffers.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    const index_t a_bytes = GetSmemSizeA();
    const index_t b_bytes = GetSmemSizeB();
    return a_bytes + b_bytes;
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user