commit e6cd3f1e3f
Author: Ding, Yi
Date: 2026-03-11 23:03:20 -04:00
6330 changed files with 1132789 additions and 0 deletions


@@ -0,0 +1,410 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmPipelineAgBgCrImplBase
{
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
template <typename T>
using has_bcastpolicy_type = decltype(T::BCastPolicy);
static constexpr bool IsBCastPolicyBeforeLDSWrite = [] {
if constexpr(is_detected<has_bcastpolicy_type, Problem>{})
{
return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
}
else
{
return false;
}
}();
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite, ADataType, BInDataType>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
#if defined(__gfx950__)
// The combination of pk_int4_t and transposed loading causes compilation errors.
// Therefore do not use transposed loading in this case.
// Also, transpose load (ds_read_tr) requires specific tile distribution patterns
// that only work for certain K warp tile sizes based on data type size:
// - For 1-byte types (fp8/bf8): K warp tile <= 64
// - For 2-byte types (fp16/bf16): K warp tile <= 32
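// Worked example (hypothetical config): a 16x16x32 fp16 warp tile gives
// kKWarpTile = 32 <= 32, so column-major A may use transpose load; raising
// the warp tile to K = 64 (64 > 32 for a 2-byte type) disables it.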
static constexpr bool is_a_load_tr = []() {
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
if constexpr(std::is_same_v<ADataType, float>)
return false;
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
return std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
}();
static constexpr bool is_b_load_tr = []() {
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
if constexpr(std::is_same_v<BDataType, float>)
return false;
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
return std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
}();
#else
static constexpr bool is_a_load_tr = false;
static constexpr bool is_b_load_tr = false;
#endif
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
template <index_t UnaryOpSize = 8,
typename DstBlockTile,
typename SrcTileWindow,
typename DramTileWindowStep>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window,
const DramTileWindowStep& dram_tile_window_step) const
{
load_and_convert_tile<UnaryOpSize>(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, dram_tile_window_step);
}
template <typename DstBlockWindow, typename SrcTileWindow, typename DramTileWindowStep>
CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_window,
SrcTileWindow& dram_tile_window,
const DramTileWindowStep& dram_tile_window_step) const
{
async_load_tile(dst_block_window, dram_tile_window);
move_tile_window(dram_tile_window, dram_tile_window_step);
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
}
template <typename DstTileWindow, typename SrcBlockTile>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile) const
{
store_tile(lds_tile_window, src_block_tile);
}
template <typename DstBlockTile, typename SrcTileWindow, bool LoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
const SrcTileWindow& lds_tile_window,
bool_constant<LoadTranspose> = {}) const
{
if constexpr(LoadTranspose)
load_tile_transpose(dst_block_tile, lds_tile_window);
else
load_tile(dst_block_tile, lds_tile_window);
}
template <typename OverrideADataType = ADataType, typename OverrideBDataType = BDataType>
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
{
// A tile in LDS
OverrideADataType* __restrict__ p_a_lds = static_cast<OverrideADataType*>(p_smem);
constexpr auto a_lds_block_desc =
Policy::template MakeALdsBlockDescriptor<Problem, OverrideADataType>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
constexpr index_t a_lds_block_space_size =
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size() / APackedSize;
constexpr index_t a_lds_block_space_size_aligned =
integer_least_multiple(a_lds_block_space_size, 16);
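// Sketch of the arithmetic for a hypothetical 128x32 fp16 A block:
// 128 * 32 * sizeof(fp16) = 8192 bytes, already a multiple of 16, so
// integer_least_multiple(8192, 16) == 8192 and the B tile starts at that offset.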
// B tile in LDS
OverrideBDataType* __restrict__ p_b_lds = static_cast<OverrideBDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
}
template <typename DramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
nullptr>
CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
const array<index_t, 2>& offset = {0, 0}) const
{
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
// A DRAM tile window for load
auto a_copy_dram_window = generate_tuple(
[&](auto idx) {
return make_tile_window(
dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
dram_block_window_tmp[number<idx>{}].get_window_origin() + offset,
Policy::template MakeADramTileDistribution<Problem>());
},
number<DramBlockWindowTmp::size()>{});
return a_copy_dram_window;
}
template <typename DramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
nullptr>
CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
const array<index_t, 2>& offset = {0, 0}) const
{
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
dram_block_window_tmp.get_window_origin() + offset,
Policy::template MakeADramTileDistribution<Problem>());
return a_copy_dram_window;
}
template <typename DramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
nullptr>
CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
const array<index_t, 2>& offset = {0, 0}) const
{
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
// B DRAM tile window for load
auto b_copy_dram_window = generate_tuple(
[&](auto idx) {
return make_tile_window(
dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
dram_block_window_tmp[number<idx>{}].get_window_origin() + offset,
Policy::template MakeBDramTileDistribution<Problem>());
},
number<DramBlockWindowTmp::size()>{});
return b_copy_dram_window;
}
template <typename DramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
nullptr>
CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
const array<index_t, 2>& offset = {0, 0}) const
{
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
dram_block_window_tmp.get_window_origin() + offset,
Policy::template MakeBDramTileDistribution<Problem>());
return b_copy_dram_window;
}
template <typename ALdsTensorView, typename ALdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr&) const
{
auto a_lds_shape = []() {
if constexpr(is_a_load_tr)
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
else
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
}();
auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0});
auto a_lds_load_tile_distr = []() {
if constexpr(is_a_load_tr)
{
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename ALdsLoadTileDistr::DstrEncode,
typename ALdsTensorView::DataType>::TransposedDstrEncode{});
}
else
{
return ALdsLoadTileDistr{};
}
}();
auto a_lds_gemm_window =
make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr);
return make_tuple(std::move(a_copy_lds_window), std::move(a_lds_gemm_window));
}
template <
typename ADramBlockWindowTmp,
typename ALdsTensorView,
typename ALdsLoadTileDistr,
typename std::enable_if_t<!is_detected<is_tuple, ALdsTensorView>::value, bool>* = nullptr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr& a_lds_load_tile_distr,
const array<index_t, 2>& offset = {0, 0}) const
{
// A DRAM tile window for load
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
// Create LDS windows
auto [a_copy_lds_window, a_lds_gemm_window] =
MakeALdsWindows(a_lds_block_view, a_lds_load_tile_distr);
return make_tuple(std::move(a_copy_dram_window),
std::move(a_copy_lds_window),
std::move(a_lds_gemm_window));
}
// Unified GetAWindows that supports 1, 2, or 3 LDS buffers
template <typename ADramBlockWindowTmp,
typename ALdsTensorViewsTuple,
typename ALdsLoadTileDistr,
typename std::enable_if_t<is_detected<is_tuple, ALdsTensorViewsTuple>::value, bool>* =
nullptr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorViewsTuple& a_lds_block_views_tuple,
const ALdsLoadTileDistr& a_lds_load_tile_distr,
const array<index_t, 2>& offset = {0, 0}) const
{
// A DRAM tile window for load
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
// Create LDS windows for each buffer
constexpr index_t num_buffers = ALdsTensorViewsTuple::size();
auto a_lds_windows = generate_tuple(
[&](auto i) {
return MakeALdsWindows(a_lds_block_views_tuple[i], a_lds_load_tile_distr);
},
number<num_buffers>{});
// Return: (dram_window, lds_windows_tuple)
// lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i)
return make_tuple(std::move(a_copy_dram_window), std::move(a_lds_windows));
}
template <typename BLdsTensorView, typename BLdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto MakeBLdsWindows(const BLdsTensorView& b_lds_block_view,
const BLdsLoadTileDistr&) const
{
auto b_lds_shape = []() {
if constexpr(is_b_load_tr)
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
else
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
}();
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
auto b_lds_load_tile_distr = []() {
if constexpr(is_b_load_tr)
{
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename BLdsLoadTileDistr::DstrEncode,
typename BLdsTensorView::DataType>::TransposedDstrEncode{});
}
else
{
return BLdsLoadTileDistr{};
}
}();
auto b_lds_gemm_window =
make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr);
return make_tuple(std::move(b_copy_lds_window), std::move(b_lds_gemm_window));
}
template <
typename BDramBlockWindowTmp,
typename BLdsTensorView,
typename BLdsLoadTileDistr,
typename std::enable_if_t<!is_detected<is_tuple, BLdsTensorView>::value, bool>* = nullptr>
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view,
const BLdsLoadTileDistr& b_lds_load_tile_distr,
const array<index_t, 2>& offset = {0, 0}) const
{
// B DRAM tile window for load
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
// Create LDS windows
auto [b_copy_lds_window, b_lds_gemm_window] =
MakeBLdsWindows(b_lds_block_view, b_lds_load_tile_distr);
return make_tuple(std::move(b_copy_dram_window),
std::move(b_copy_lds_window),
std::move(b_lds_gemm_window));
}
// Unified GetBWindows that supports 1, 2, or 3 LDS buffers
template <typename BDramBlockWindowTmp,
typename BLdsTensorViewsTuple,
typename BLdsLoadTileDistr,
typename std::enable_if_t<is_detected<is_tuple, BLdsTensorViewsTuple>::value, bool>* =
nullptr>
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorViewsTuple& b_lds_block_views_tuple,
const BLdsLoadTileDistr& b_lds_load_tile_distr,
const array<index_t, 2>& offset = {0, 0}) const
{
// B DRAM tile window for load
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
// Create LDS windows for each buffer
constexpr index_t num_buffers = BLdsTensorViewsTuple::size();
auto b_lds_windows = generate_tuple(
[&](auto i) {
return MakeBLdsWindows(b_lds_block_views_tuple[i], b_lds_load_tile_distr);
},
number<num_buffers>{});
// Return: (dram_window, lds_windows_tuple)
// lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i)
return make_tuple(std::move(b_copy_dram_window), std::move(b_lds_windows));
}
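// Typical consumption of the tuple overload above (a sketch; names and the
// tuple indexing are illustrative, per the return-shape comment):
//   auto [b_dram_w, b_lds_ws] = GetBWindows(b_dram_win, b_lds_views, b_distr);
//   auto& b_copy_w0 = b_lds_ws[number<0>{}][number<0>{}]; // copy window, buffer 0
//   auto& b_gemm_w0 = b_lds_ws[number<0>{}][number<1>{}]; // gemm window, buffer 0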
};
} // namespace ck_tile


@@ -0,0 +1,653 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompAsync
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop == 1)
{
return TailNumber::One;
}
if(num_loop % PrefetchStages == 1)
{
return TailNumber::Three;
}
else
{
return TailNumber::Two;
}
}
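// Example mapping with PrefetchStages == 2 (a non-normative reading of the
// logic above): num_loop == 1 -> One; any other odd num_loop -> Three
// (num_loop % 2 == 1); even num_loop -> Two.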
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
// Use amd_wave_read_first_lane to avoid higher resource usage:
// it forces these values into SGPRs, since the compiler cannot
// deduce that all threads take the same path.
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
// Handle all the valid cases.
if(has_hot_loop_first_lane)
{
if(tail_number_first_lane == TailNumber::Three)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Three>{});
}
else if(tail_number_first_lane == TailNumber::Two)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Two>{});
}
}
else
{
if(tail_number_first_lane == TailNumber::Three)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Three>{});
}
else if(tail_number_first_lane == TailNumber::Two)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Two>{});
}
else
{
return (run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::One>{}));
}
}
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
#if defined(__HIP_DEVICE_COMPILE__)
__builtin_unreachable();
#else
throw std::logic_error("Invalid TailNumber: with a hot loop only "
"TailNumber::Two and TailNumber::Three are supported");
#endif
}
CK_TILE_HOST static constexpr auto GetName() { return "COMPUTE_ASYNC"; }
};
/**
* @brief Compute-optimized pipeline, asynchronous variant; based on V4.
*
* This pipeline loads asynchronously from global memory straight into LDS,
* skipping the intermediate copy through pipeline registers.
*/
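// Rough stage diagram of the schedule implemented below (a sketch, not
// normative): async prefetch(0)->LDS0, prefetch(1)->LDS1, local prefetch
// LDS0->regs0, prefetch(2)->LDS0, then the hot loop ping-pongs:
//   ping: regs1 <- LDS1, prefetch(i)->LDS1,   gemm(regs0)
//   pong: regs0 <- LDS0, prefetch(i+1)->LDS0, gemm(regs1)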
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompAsync<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr bool Async = true;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer");
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "COMPUTE_ASYNC";
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
return 2 * smem_size;
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t WaveSize = get_warp_size();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;
constexpr auto num_issue = num_buffer_load_inst;
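// Example counts for a hypothetical 128x128x64 block with 256 threads and
// vector size 8: A loads = 128*64/(256*8) = 4, B loads = 4, so 8 issue
// slots, each ending with C_MFMA_Inst_Num / 8 - 2 trailing MFMAs.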
static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
// TODO: this will likely need to be redesigned after (1) changes to reading from
// LDS and (2) re-profiling
ignore = i;
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read : 1
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read : 1
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::MFMA, C_MFMA_Inst_Num / num_issue - 2, 0); // remaining MFMAs
});
__builtin_amdgcn_sched_barrier(0);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem) const
{
// TODO support multi-ABD
static_assert(1 == std::tuple_size_v<AsDramBlockWindowTmp>);
static_assert(1 == std::tuple_size_v<BsDramBlockWindowTmp>);
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
// TODO currently fused elementwise are not supported
ignore = a_element_func;
ignore = b_element_func;
static_assert(std::is_same_v<remove_cvref_t<decltype(a_element_func)>,
element_wise::PassThrough>);
static_assert(std::is_same_v<remove_cvref_t<decltype(b_element_func)>,
element_wise::PassThrough>);
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"Data Type conflict on A and B matrix input data type.");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
////////////// global window & register /////////////////
// A DRAM tile window(s) for load
auto a_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
// B DRAM window(s) for load
auto b_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
},
number<BsLayout::size()>{});
// this pipeline has a pair of LDS buffers per logical tile
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem);
auto&& [a_lds_block1, b_lds_block1] =
Base::GetABLdsTensorViews(static_cast<char*>(p_smem) + smem_size);
// set up LDS tile shapes
constexpr auto a_lds_shape = []() {
if constexpr(is_a_load_tr_v)
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
else
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
}();
constexpr auto b_lds_shape = []() {
if constexpr(is_b_load_tr_v)
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
else
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
}();
// LDS tile windows for storing, one per LDS buffer
auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0});
auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0});
auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0});
auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0});
// initialize DRAM window steps, used to advance the DRAM windows
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// read A(0), B(0) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
// initialize block gemm
auto block_gemm = BlockGemm();
// initialize C block tile
auto c_block_tile = block_gemm.MakeCBlockTile();
clear_tile(c_block_tile);
// read A(1), B(1) from DRAM to LDS window(1)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window1, a_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
// tile distribution for the register tiles
constexpr auto ALdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto BLdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
// register tiles; double buffering -> a register tile corresponds to an LDS tile window
ALdsTile a_block_tile0, a_block_tile1;
BLdsTile b_block_tile0, b_block_tile1;
constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
if constexpr(is_a_load_tr_v)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename decltype(ALdsTileDistr)::DstrEncode,
typename Problem::ADataType>::TransposedDstrEncode{});
else
return ALdsTileDistr;
}();
constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() {
if constexpr(is_b_load_tr_v)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename decltype(BLdsTileDistr)::DstrEncode,
typename Problem::BDataType>::TransposedDstrEncode{});
else
return BLdsTileDistr;
}();
// LDS tile windows for reading;
// they share the data pointer with the LDS windows for storing,
// but additionally carry a tile distribution that produces a register tile on read
auto a_lds_ld_window0 =
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
auto a_lds_ld_window1 =
make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
auto b_lds_ld_window0 =
make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
auto b_lds_ld_window1 =
make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
static_assert(!(is_tile_window_linear_v<decltype(a_lds_ld_window0)>) &&
!(is_tile_window_linear_v<decltype(a_lds_ld_window1)>) &&
!(is_tile_window_linear_v<decltype(b_lds_ld_window0)>) &&
!(is_tile_window_linear_v<decltype(b_lds_ld_window1)>),
"LDS windows must not be linear");
// write to LDS window(0) must complete before the local prefetch
block_sync_lds_direct_load();
// read A(0), B(0) from LDS window(0) to pipeline registers(0)
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
// LDS window(0) contents are overwritten below by global prefetch, need to sync
block_sync_lds();
// read A(2), B(2) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
if constexpr(HasHotLoop)
{
// we have had 3 global prefetches so far, indexed (0, 1, 2).
index_t i_global_read = amd_wave_read_first_lane(3);
// Alternate between:
// ping: read into register tile(1), use register tile(0) as gemm input
// pong: read into register tile(0), use register tile(1) as gemm input
do
{
// ping
{
// read A(i-1), B(i-1) from LDS window(1) to pipeline registers(1)
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
// LDS window(1) contents are overwritten by global prefetch, need to sync
block_sync_lds();
// read A(i), B(i) from DRAM to LDS window(1)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(a_copy_lds_window1,
a_tile_windows[number<0>{}],
a_dram_tile_window_step);
Base::GlobalPrefetchAsync(b_copy_lds_window1,
b_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-3) = A(i-3) @ B(i-3)
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
}
// pong
{
// write to LDS window(0) must complete before the local prefetch
block_sync_lds_direct_load();
// read A(i), B(i) from LDS window(0) to pipeline registers(0)
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
// LDS window(0) contents are overwritten by global prefetch, need to sync
block_sync_lds();
// read A(i+1), B(i+1) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(a_copy_lds_window0,
a_tile_windows[number<0>{}],
a_dram_tile_window_step);
Base::GlobalPrefetchAsync(b_copy_lds_window0,
b_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-2) = A(i-2) @ B(i-2)
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
}
i_global_read += 2;
} while(i_global_read < num_loop);
}
// 3 block gemms remaining
if constexpr(TailNum == TailNumber::Three)
{
{
// read A(num_loop-1), B(num_loop-1) from LDS window(1) to pipeline registers(1)
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
// C(num_loop-2) = A(num_loop-2) @ B(num_loop-2)
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
{
// write to LDS window(0) must complete before the local prefetch
block_sync_lds_direct_load();
// read A(num_loop), B(num_loop) from LDS window(0) to pipeline registers(0)
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1)
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
{
// C(num_loop) = A(num_loop) @ B(num_loop)
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
}
// 2 block gemms remaining
else if constexpr(TailNum == TailNumber::Two)
{
{
// read A(num_loop), B(num_loop) from LDS window(1) to pipeline registers(1)
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1)
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
{
// C(num_loop) = A(num_loop) @ B(num_loop)
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
}
else if constexpr(TailNum == TailNumber::One)
{
block_sync_lds();
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
}
return c_block_tile;
}
};
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
public:
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
element_wise::PassThrough{},
b_dram_block_window_tmp,
element_wise::PassThrough{},
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
ck_tile::make_tuple(a_dram_block_window_tmp),
a_element_func,
ck_tile::make_tuple(b_dram_block_window_tmp),
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
public:
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
ck_tile::make_tuple(a_dram_block_window_tmp),
element_wise::PassThrough{},
ck_tile::make_tuple(b_dram_block_window_tmp),
element_wise::PassThrough{},
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
};
} // namespace ck_tile


@@ -0,0 +1,132 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
namespace ck_tile {
// Default policy for GemmPipelineAgBgCrCompAsync
// Customized methods: MakeALdsBlockDescriptor, MakeBLdsBlockDescriptor
// GetBlockGemm implementation is copied from GemmPipelineAgBgCrCompV4DefaultPolicy
struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompAsyncDefaultPolicy>
{
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
template <typename Problem,
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(is_a_load_tr<Problem>)
{
// TODO: better LDS descriptor for performance
// This branch is reusing the logic from
// UniversalGemmBasePolicy::MakeALdsBlockDescriptor
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<KPerBlock>{}, number<MPerBlock>{}),
make_tuple(number<MPerBlock>{}, number<1>{}),
number<MPerBlock>{},
number<1>{});
return a_lds_block_desc_0;
}
else
{
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
return transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(is_b_load_tr<Problem>)
{
// TODO: better LDS descriptor for performance
// This branch is reusing the logic from
// UniversalGemmBasePolicy::MakeBLdsBlockDescriptor
constexpr auto b_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<KPerBlock>{}, number<NPerBlock>{}),
make_tuple(number<NPerBlock>{}, number<1>{}),
number<NPerBlock>{},
number<1>{});
return b_lds_block_desc_0;
}
else
{
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
return transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NPerBlock>{}),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t vector_size =
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
constexpr auto wg_attr_num_access =
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
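// Example (hypothetical, assuming DS_READ_TR_SIZE() == 16): a 16x16x32 fp16
// warp tile gives thread_elements = 16*32/64 = 8 and vector_size = 16/2 = 8,
// so a transpose-load path selects WGAttrNumAccessEnum::Single.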
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType, // AccDataType
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile


@@ -0,0 +1,794 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV3
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
if constexpr(Problem::BlockGemmShape::NumWarps == 8)
return num_loop > 3;
else
return num_loop > PrefetchStages;
}
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(BlockHasHotloop(num_loop) || num_loop == 3)
{
if constexpr(Problem::BlockGemmShape::NumWarps == 8)
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
else
return TailNumber::Odd;
}
else if(num_loop == 2)
return TailNumber::Even;
else
return (Problem::BlockGemmShape::NumWarps == 8) ? TailNumber::One : TailNumber::Odd;
}
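// Example mapping (non-normative): with NumWarps == 8, num_loop 1/2/3/4/5 ->
// One/Even/Odd/Even/Odd; otherwise num_loop == 2 maps to Even and every
// other num_loop maps to Odd.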
template <size_t I = 0, typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
// Use amd_wave_read_first_lane to avoid higher resource usage:
// it forces these values into SGPRs, since the compiler cannot
// deduce that all threads take the same path.
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
constexpr auto scenarios = []() {
if constexpr(Problem::BlockGemmShape::NumWarps == 8)
return std::array<std::pair<bool, ck_tile::TailNumber>, 5>{
std::make_pair(false, TailNumber::One), // 1 loop
std::make_pair(false, TailNumber::Even), // 2 loop
std::make_pair(false, TailNumber::Odd), // 3
std::make_pair(true, TailNumber::Even), // 4 / 6 / 8 / ... loops
std::make_pair(true, TailNumber::Odd), // 5 / 7 / 9 / ... loops
};
else
return std::array<std::pair<bool, ck_tile::TailNumber>, 3>{
std::make_pair(true, TailNumber::Odd),
std::make_pair(false, TailNumber::Odd),
std::make_pair(false, TailNumber::Even),
};
}();
if(has_hot_loop_first_lane == scenarios[I].first &&
tail_number_first_lane == scenarios[I].second)
return run_func(bool_constant<scenarios[I].first>{}, constant<scenarios[I].second>{});
else if constexpr(I + 1 < scenarios.size())
return TailHandler<I + 1>(run_func, has_hot_loop, tail_number);
#if defined(__HIP_DEVICE_COMPILE__)
// This path should be unreachable in device code if tail_number is valid.
__builtin_unreachable();
#else
// If execution reaches here, it's an invalid combination of arguments.
throw std::logic_error("Invalid TailNumber value: must be "
"TailNumber::Odd or TailNumber::Even");
#endif
}
};
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
using Base::PrefetchStages;
using Base::UsePersistentKernel;
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "COMPUTE_V3";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
return concat('_', "pipeline_AgBgCrCompV3",
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', WaveNumM, WaveNumN),
concat('x', kPadM, kPadN, kPadK),
Problem::GetName());
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST static std::string Print()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
// Below should be equal to AK1|BK1
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
auto str = std::stringstream{};
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
<< "\n"
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
// Below should be equal to AK1|BK1
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
// A/B split schedule:
// the compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
constexpr auto num_ds_read_inst_a =
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_a_mfma =
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_b_mfma =
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
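// Example (hypothetical): mfma_cycle = 32 and a 16-byte ds_read (issue
// cycle 8) give ds_read_a_mfma_rate = (32 - 4 + 16 - 1) / 16 = 2, i.e. up
// to two ds_reads are scheduled per MFMA slot in stage 2 below.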
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read =
//     sizeof(ComputeDataType) / sizeof(ADataType) > sizeof(ComputeDataType) / sizeof(BDataType)
//         ? sizeof(ComputeDataType) / sizeof(ADataType)
//         : sizeof(ComputeDataType) / sizeof(BDataType);
constexpr auto num_mfma_stage1 =
num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
});
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
});
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
}
template <bool HasHotLoop,
TailNumber TailNum,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A/B tiles in LDS
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
// Block GEMM
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// Load tiles: while loading, the elementwise function is applied to each of
// A0, A1, ..., AN; all of these values are read by the same thread.
auto elementwise_As_res =
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
// Advance all A windows: this move_tile_window overload accepts a tuple of windows.
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
// Same for B: the elementwise function is applied to each of B0, B1, ..., BN,
// all read by the same thread.
auto elementwise_Bs_res =
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
// Advance all B windows.
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// LDS write 0
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
// global read 1
elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
__builtin_amdgcn_sched_barrier(0);
// main body
if constexpr(HasHotLoop)
{
index_t i = 0;
do
{
block_sync_lds();
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
elementwise_As_res =
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
elementwise_Bs_res =
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Odd)
{
// Defer the last MFMA block into the epilogue region to hide the potential
// LDS-shuffle latency.
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
else
{
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
block_sync_lds();
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// __builtin_amdgcn_sched_barrier(0);
return c_block_tile;
}
};
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
/**
* @brief This function runs the pipeline by wrapping it with the tail handler.
*
* @note This is used by the persistent gemm kernel variants that don't determine
* hot loop and tail number on the host side, e.g. grouped gemm kernel.
*/
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* p_smem) const
{
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr bool hot_loop = hot_loop_.value;
constexpr auto tail_num = tail_num_.value;
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
PassThrough,
b_dram_block_window_tmp,
PassThrough,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
/**
* @brief This function runs the pipeline using compile-time known hot loop and tail number.
* @param num_loop The number of loop iterations. This is determined at runtime due to e.g.
* SplitK.
* @note This is used by the kernel variants that are able to determine
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
*/
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
return operator()(a_dram_block_window_tmp,
b_dram_block_window_tmp,
num_loop,
has_hot_loop,
tail_number,
p_smem);
}
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
!is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
a_element_func,
ck_tile::make_tuple(b_dram_block_window_tmp),
b_element_func,
num_loop,
p_smem);
}
/**
* @brief Single-input (non-tuple) operator(): runs the pipeline by wrapping it with
* the tail handler.
*
* @note This is used by the persistent gemm kernel variants that don't determine
* hot loop and tail number on the host side, e.g. grouped gemm kernel.
*/
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
has_hot_loop,
tail_number,
p_smem);
}
/**
* @brief Single-input (non-tuple) operator(): runs the pipeline using compile-time
* known hot loop and tail number.
* @param num_loop The number of loop iterations. This is determined at runtime due to e.g.
* SplitK.
* @note This is used by the kernel variants that are able to determine
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
*/
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,823 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV4
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
constexpr index_t HotLoopGlobalReads = 2;
return num_loop >= (HotLoopGlobalReads + PrefetchStages);
}
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop == 1)
{
return TailNumber::One;
}
if(num_loop % PrefetchStages == 1)
{
return TailNumber::Three;
}
else
{
return TailNumber::Two;
}
}
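// Tail mapping sketch (PrefetchStages == 2): num_loop == 1 -> One, even
// num_loop -> Two, odd num_loop > 1 -> Three. BlockHasHotloop() only returns
// true from num_loop >= 4, i.e. two in-flight global reads plus both prefetch
// stages.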
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
// Use amd_wave_read_first_lane to avoid higher resource usage.
// It forces to store these values in SGPR.
// Compiler cannot deduce if one path is used for all threads
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
// Handle all the valid cases.
if(has_hot_loop_first_lane)
{
if(tail_number_first_lane == TailNumber::Three)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Three>{});
}
else if(tail_number_first_lane == TailNumber::Two)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Two>{});
}
}
else
{
if(tail_number_first_lane == TailNumber::Three)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Three>{});
}
else if(tail_number_first_lane == TailNumber::Two)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Two>{});
}
else
{
return (run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::One>{}));
}
}
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
#if defined(__HIP_DEVICE_COMPILE__)
__builtin_unreachable();
#else
throw std::logic_error("Invalid TailNumber: Only TailNumber::Full and smaller than "
"PrefetchStages are supported.");
#endif
}
};
/**
* @brief Compute optimized pipeline version 4
*
* This version introduces a dual-LDS-window mechanism using a ping-pong buffer approach
* for more efficient data movement from global memory. Unlike compute version 3, one LDS
* buffer is filled from global memory while the other feeds the warps executing MFMA
* matrix multiplication. This overlap keeps the warps continuously busy, hiding memory
* load latency and enhancing overall performance.
*
* @note This version shows improved performance over Compute Version 3 with the same block tile.
* It is particularly more efficient for large matrices where M, N, and K are greater than 8K,
* even when Compute Version 3's block size is twice that of Compute Version 4.
*/
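// Illustrative hot-loop timeline (sketch only, see PipelineImpl below):
//   iter k   : fill LDS buffer 0 from global | MFMA consumes LDS buffer 1
//   iter k+1 : fill LDS buffer 1 from global | MFMA consumes LDS buffer 0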
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV4DefaultPolicy>
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV4<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer");
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "COMPUTE_V4";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AgBgCrCompV4",
concat('x', MPerBlock, NPerBlock, KPerBlock, BlockSize),
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
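// Twice the policy's single-stage size: the ping-pong scheme keeps two LDS
// buffers alive at once.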
return 2 * smem_size;
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t A_LDS_Read_Width = KPerXDL;
constexpr index_t B_LDS_Read_Width = KPerXDL;
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
constexpr auto num_ds_read_inst_a =
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b;
constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;
constexpr auto num_issue = num_buffer_load_inst;
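// Illustrative per-issue slot pattern (hypothetical counts with
// num_ds_read_inst / num_issue == 2 and C_MFMA_Inst_Num / num_issue == 8):
// MFMA, DS read x2, MFMA, DS write, MFMA, VMEM read, MFMA x5.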
static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA x1
__builtin_amdgcn_sched_group_barrier(
0x100, num_ds_read_inst / num_issue, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA x1
__builtin_amdgcn_sched_group_barrier(
0x200, num_ds_write_inst / num_issue, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA x1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read x1
__builtin_amdgcn_sched_group_barrier(
0x008, C_MFMA_Inst_Num / num_issue - 3, 0); // remaining MFMA
});
__builtin_amdgcn_sched_barrier(0);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"Data Type conflict on A and B matrix input data type.");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// global prefetch 0
// global read 0
////////////// LDS desc, window & register /////////////////
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem);
auto&& [a_lds_block1, b_lds_block1] =
Base::GetABLdsTensorViews(static_cast<char*>(p_smem) + smem_size);
constexpr auto a_lds_shape = []() {
if constexpr(is_a_load_tr_v())
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
else
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
}();
auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0});
auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0});
constexpr auto b_lds_shape = []() {
if constexpr(is_b_load_tr_v())
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
else
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
}();
auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0});
auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0});
// Block GEMM
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// Generating a tuple with tile_windows for values A0, A1, ... AN
auto a_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
// Load tiles: while loading, the elementwise function is applied to each of
// A0, A1, ..., AN; all of these values are read by the same thread.
auto elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
// Advance all A windows: this move_tile_window overload accepts a tuple of windows.
move_tile_window(a_tile_windows, a_dram_tile_window_step);
// Generating a tuple with tile_windows for values B0, B1, ... BN
auto b_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
// Same for B: the elementwise function is applied to each of B0, B1, ..., BN,
// all read by the same thread.
auto elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
// Advance all B windows.
move_tile_window(b_tile_windows, b_dram_tile_window_step);
// LDS write 0
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
}
// global read 1
elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
move_tile_window(a_tile_windows, a_dram_tile_window_step);
elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
move_tile_window(b_tile_windows, b_dram_tile_window_step);
block_sync_lds();
constexpr auto ALdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto BLdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_block_tile0, a_block_tile1;
BLdsTile b_block_tile0, b_block_tile1;
constexpr auto a_lds_input_tile_distr = [&]() {
if constexpr(is_a_load_tr_v())
return make_static_tile_distribution(
typename InputTileDistributionTraits<
decltype(BlockGemm::MakeABlockDistributionEncode()),
typename Problem::ADataType>::TransposedDstrEncode{});
else
return ALdsTileDistr;
}();
constexpr auto b_lds_input_tile_distr = [&]() {
if constexpr(is_b_load_tr_v())
return make_static_tile_distribution(
typename InputTileDistributionTraits<
decltype(BlockGemm::MakeBBlockDistributionEncode()),
typename Problem::BDataType>::TransposedDstrEncode{});
else
return BLdsTileDistr;
}();
auto a_lds_ld_window0 =
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
auto a_lds_ld_window1 =
make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
auto b_lds_ld_window0 =
make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
auto b_lds_ld_window1 =
make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
static_assert(!is_tile_window_linear_v<decltype(a_lds_ld_window0)> &&
!is_tile_window_linear_v<decltype(a_lds_ld_window1)> &&
!is_tile_window_linear_v<decltype(b_lds_ld_window0)> &&
!is_tile_window_linear_v<decltype(b_lds_ld_window1)>,
"LDS windows must not be linear");
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res);
}
// Global read 2: apply the elementwise function to each of A0..AN / B0..BN while
// loading; each tuple element is read by the same thread.
elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
move_tile_window(a_tile_windows, a_dram_tile_window_step);
elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
move_tile_window(b_tile_windows, b_dram_tile_window_step);
if constexpr(HasHotLoop)
{
// minus 2 because we have ping-pong double buffer.
index_t iCounter = amd_wave_read_first_lane(num_loop - 2);
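// E.g. num_loop == 10 (TailNumber::Two): iCounter starts at 8, the loop runs
// four ping-pong pairs (8 compute steps), and the tail performs the final two.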
do
{
// ping
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
}
elementwise_As_res =
load_tile_with_elementwise(a_tile_windows, a_element_func);
move_tile_window(a_tile_windows, a_dram_tile_window_step);
elementwise_Bs_res =
load_tile_with_elementwise(b_tile_windows, b_element_func);
move_tile_window(b_tile_windows, b_dram_tile_window_step);
// gemm
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
}
// pong
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res);
}
block_sync_lds();
elementwise_As_res =
load_tile_with_elementwise(a_tile_windows, a_element_func);
move_tile_window(a_tile_windows, a_dram_tile_window_step);
elementwise_Bs_res =
load_tile_with_elementwise(b_tile_windows, b_element_func);
move_tile_window(b_tile_windows, b_dram_tile_window_step);
// gemm
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
}
iCounter -= 2;
} while(iCounter > 1);
}
// tail 3
if constexpr(TailNum == TailNumber::Three)
{
// 3
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
}
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
// 2
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
// 1
{
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
}
}
else if constexpr(TailNum == TailNumber::Two)
{
// 2
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
static_for<0, 8, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
});
__builtin_amdgcn_sched_barrier(0);
}
// 1
{
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
__builtin_amdgcn_sched_barrier(0);
}
}
else if constexpr(TailNum == TailNumber::One)
{
block_sync_lds();
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
}
return c_block_tile;
}
};
public:
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* __restrict__ p_smem) const
{
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr bool hot_loop = hot_loop_.value;
constexpr auto tail_num = tail_num_.value;
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
PassThrough,
b_dram_block_window_tmp,
PassThrough,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
a_element_func,
ck_tile::make_tuple(b_dram_block_window_tmp),
b_element_func,
num_loop,
p_smem);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* __restrict__ p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
has_hot_loop,
tail_number,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,56 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
namespace ck_tile {
// Default policy for GemmPipelineAgBgCrCompV4. Except for the block gemm method, it
// shares the same vector size implementation, SmemSize, and global memory tile
// distribution as the UniversalGemm pipeline policy.
// Default policy class should not be templated; put templates on
// member functions instead.
struct GemmPipelineAgBgCrCompV4DefaultPolicy
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV4DefaultPolicy>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t vector_size =
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
constexpr auto wg_attr_num_access =
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
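// Example (assuming a transposed load is active and DS_READ_TR_SIZE() == 8, i.e.
// a 64-bit ds_read_tr): for fp16, vector_size = 8 / 2 = 4; with a 16x16x32 warp
// tile, thread_elements = 16 * 32 / 64 = 8, so wg_attr_num_access == Double.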
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType, // AccDataType
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,488 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed Tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV5
{
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
{
return TailNumber::Empty;
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
{
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV5DefaultPolicy>
struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV5<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t NumWarps = BlockGemmShape::NumWarps;
static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{});
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "COMPUTE_V5";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AgBgCrCompV5", BlockSize,
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
template <bool HasHotLoop,
TailNumber TailNum,
typename AsDramBlockWindowTmp,
typename AElementFunction,
typename BsDramBlockWindowTmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem_0) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"Data Type conflict on A and B matrix input data type.");
static_assert(
KPerBlock % ((NumWarps / 2) * KTileSize) == 0,
"Ping Pong Warps, TileSize and Block Size for K dimensions does not match.");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
index_t warp_id = get_warp_id();
index_t operation_id =
amd_wave_read_first_lane(get_warp_id()); // 0 - Memory read, 1 - block-gemm
auto a_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock);
auto b_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock);
auto tensor_views =
Base::GetABLdsTensorViews(static_cast<void*>(static_cast<char*>(p_smem_0)));
auto& a_lds_block = tensor_views.get(number<0>{});
auto& b_lds_block = tensor_views.get(number<1>{});
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
auto a_windows = Base::GetAWindows(
a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr, a_offset);
auto& a_copy_dram_window = a_windows.get(number<0>{});
auto& a_copy_lds_window = a_windows.get(number<1>{});
auto& a_lds_window = a_windows.get(number<2>{});
auto b_windows = Base::GetBWindows(
b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr, b_offset);
auto& b_copy_dram_window = b_windows.get(number<0>{});
auto& b_copy_lds_window = b_windows.get(number<1>{});
auto& b_lds_window = b_windows.get(number<2>{});
// DRAM window steps.
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock * NumWarps, 0)
: make_array(0, KPerBlock * NumWarps);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock * NumWarps, 0)
: make_array(0, KPerBlock * NumWarps);
constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};
using AGemmTile = decltype(make_static_distributed_tensor<ADataType>(AGemmTileDistr));
using BGemmTile = decltype(make_static_distributed_tensor<BDataType>(BGemmTileDistr));
AGemmTile a_tile_0, a_tile_1;
BGemmTile b_tile_0, b_tile_1;
// Register tile for A and B.
using ABlockTileDistr =
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
using BBlockTileDistr =
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile elementwise_As_res;
BBlockTile elementwise_Bs_res;
// Block GEMM
auto block_gemm = BlockGemm();
auto c_block_tile_0 = block_gemm.MakeCBlockTile();
auto c_block_tile_1 = block_gemm.MakeCBlockTile();
CDataType* __restrict__ p_c_lds = static_cast<CDataType*>(p_smem_0);
auto c_lds_block_0 =
make_naive_tensor_view<address_space_enum::lds>(p_c_lds,
make_tuple(MPerBlock, NPerBlock),
make_tuple(NPerBlock, 1),
number<BlockGemm::Traits::KPack>{},
number<1>{});
auto c_window_0 = make_tile_window(c_lds_block_0,
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
{0, 0},
c_block_tile_1.get_tile_distribution());
// initialize C
if(warp_id == 0)
{
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_0);
}
else
{
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_1);
}
// define ping, pong steps here as lambda functions.
auto MemoryOpsStep = [&](auto idx) {
// Memory-read half of the step.
// Load tiles: the elementwise function is applied to each of A0..AN while
// loading; all tuple elements are read by the same thread.
elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
// Advance all A windows: this move_tile_window overload accepts a tuple of windows.
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
// Same for B.
elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
if(idx == 0)
{
Base::LocalPrefetch(a_tile_0, a_lds_window);
Base::LocalPrefetch(b_tile_0, b_lds_window);
}
else
{
Base::LocalPrefetch(a_tile_1, a_lds_window);
Base::LocalPrefetch(b_tile_1, b_lds_window);
}
};
auto ComputeStep = [&](auto idx) {
if(idx == 0)
{
block_gemm(c_block_tile_0, a_tile_0, b_tile_0);
}
else
{
block_gemm(c_block_tile_1, a_tile_1, b_tile_1);
}
};
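// The two wave groups ping-pong through the loop below: per iteration, the
// group whose operation_id rolls over to zero runs MemoryOpsStep (global
// loads plus LDS traffic) while the other runs ComputeStep (MFMA on the
// previously staged tiles), so memory latency is hidden behind computation.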
if(operation_id == 0)
{
MemoryOpsStep(warp_id);
}
index_t num_compute_steps = amd_wave_read_first_lane(num_loop);
while(num_compute_steps > 1)
{
block_sync_lds();
operation_id = (operation_id + 1) % NumWaveGroups;
if(operation_id == 0)
{
MemoryOpsStep(warp_id);
}
else
{
ComputeStep(warp_id);
}
num_compute_steps -= 1;
}
block_sync_lds();
if(operation_id == 0)
{
ComputeStep(warp_id);
}
block_sync_lds();
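// Cross-wave-group reduction: the warps with warp_id == 1 park their partial
// C tile in LDS, then the warps with warp_id == 0 load it and accumulate it
// into their own partial result to form the final block tile.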
if(warp_id == 1)
{
store_tile(c_window_0, c_block_tile_1);
}
block_sync_lds();
if(warp_id == 0)
{
load_tile(c_block_tile_1, c_window_0);
constexpr auto s_spans = decltype(c_block_tile_0)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
auto idx2 = make_tuple(idx0, idx1);
c_block_tile_0(idx2) += c_block_tile_1(idx2);
});
});
}
return c_block_tile_0;
}
};
public:
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem_0) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem_0);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
a_element_func,
ck_tile::make_tuple(b_dram_block_window_tmp),
b_element_func,
num_loop,
p_smem_0);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem_0) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem_0);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,64 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
namespace ck_tile {
// Default policy for GemmPipelineAGmemBGmemCregComputeV5. Except for the block GEMM
// method, it shares the same vector-size implementation, SmemSize, and global memory
// tile distribution as the UniversalGemm pipeline policy.
// The default policy class should not be templated; put templates on the
// member functions instead.
struct GemmPipelineAgBgCrCompV5DefaultPolicy
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV5DefaultPolicy>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
// using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType, // AccDataType
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_DEVICE static constexpr index_t GetSmemSizeC()
{
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
return integer_least_multiple(sizeof(typename Problem::CDataType) * MPerBlock * NPerBlock,
16);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
constexpr index_t smem_size_c = GetSmemSizeC<Problem>();
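// The A/B staging buffers and the C shuffle buffer reuse the same LDS
// allocation, so the pipeline only needs the larger of the two footprints.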
return smem_size_a + smem_size_b >= smem_size_c ? (smem_size_a + smem_size_b)
: (smem_size_c);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,821 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV6
{
static constexpr index_t PrefetchStages = 3;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 2;
static constexpr index_t HotloopUnroll = 2;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop % HotloopUnroll == 1)
{
return TailNumber::Odd;
}
else
{
return TailNumber::Even;
}
}
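// Illustrative mapping (follows directly from the constants above): with
// HotloopUnroll == 2, num_loop == 5 has a hot loop (5 > PrefetchStages) and
// tails with TailNumber::Odd, while num_loop == 6 tails with TailNumber::Even.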
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
// Use amd_wave_read_first_lane to avoid higher register usage: it forces these
// values into SGPRs, since the compiler cannot deduce on its own that all
// threads take the same path.
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
// Handle all the valid cases.
if(has_hot_loop_first_lane)
{
if(tail_number_first_lane == TailNumber::Odd)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
else if(tail_number_first_lane == TailNumber::Even)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Even>{});
}
}
else
{
if(tail_number_first_lane == TailNumber::Odd)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
else if(tail_number_first_lane == TailNumber::Even)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Even>{});
}
}
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
#if defined(__HIP_DEVICE_COMPILE__)
__builtin_unreachable();
#else
throw std::logic_error("Invalid TailNumber: Only TailNumber::Odd and TailNumber::Even are "
"supported in this pipeline context.");
#endif
}
};
// Compute optimized pipeline
// GlobalPrefetchStages: 3
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV6DefaultPolicy>
struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV6<Problem>;
using BasePImpl = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr index_t KRepeat = BlockGemm::WarpGemm::kKPerThread / GetSmemPackA();
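// KRepeat is the number of LDS read rounds per warp GEMM: each thread consumes
// kKPerThread K-elements, fetched GetSmemPackA() elements per read.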
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<BasePImpl::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<BasePImpl::is_b_load_tr>{};
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "COMPUTE_V6";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AgBgCrCompV6", BlockSize,
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK),
concat('_', KRepeat),
concat('_', DoubleSmemBuffer),
concat('_', Preshuffle));
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public BasePImpl
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public BasePImpl
{
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1);
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2);
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0);
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1);
constexpr index_t A_LDS_Read_Width = KPerXDL;
constexpr index_t B_LDS_Read_Width = KPerXDL;
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
constexpr auto num_ds_read_inst_a =
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
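// ds_read-per-MFMA interleave rates, computed as
// ceil((mfma_cycle - 4) / (2 * ds_read issue cycles)); the constants model the
// MFMA latency available for co-issuing LDS reads.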
constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat;
constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat;
constexpr auto num_dsread_stage1_a_mfma =
(num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_stage1_b_mfma =
(num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
constexpr auto num_dsread_stage3_a_mfma =
(num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_stage3_b_mfma =
(num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
constexpr auto num_mfma_stage2 = C_MFMA_Inst_Num -
num_ds_read_inst_a / ds_read_a_mfma_rate -
num_ds_read_inst_b / ds_read_b_mfma_rate;
constexpr auto num_mfma_per_issue =
num_mfma_stage2 / (A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num);
constexpr auto num_dswrite_per_issue_a = A_LDS_Write_Inst_Num / A_Buffer_Load_Inst_Num;
constexpr auto num_dswrite_per_issue_b = B_LDS_Write_Inst_Num / B_Buffer_Load_Inst_Num;
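// The schedule below interleaves instruction classes in three stages: stage 1
// issues the bulk of the ds_reads for the next K fragments between MFMAs,
// stage 2 spreads the ds_writes and buffer_loads of the prefetched tile across
// the remaining MFMAs, and stage 3 drains the last ds_read round.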
// stage 1
static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) {
ignore = i;
if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, num_dsread_stage1_b_mfma, 1>{}([&](auto i) {
ignore = i;
if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
// stage 2
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
});
// stage 3
static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) {
ignore = i;
if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) {
ignore = i;
if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_barrier(0);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem) const
{
// TODO: Add Multi A/B support
static_assert(std::tuple_size<remove_cvref_t<AsDramBlockWindowTmp>>::value == 1,
"Multi A/B is not yet supported for this pipeline.");
static_assert(std::tuple_size<remove_cvref_t<BsDramBlockWindowTmp>>::value == 1,
"Multi A/B is not yet supported for this pipeline.");
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"Data Type conflict on A and B matrix input data type.");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1]),
"B block window has incorrect lengths for defined BLayout!");
////////////// LDS desc, window & register /////////////////
using ALdsType =
remove_cvref_t<decltype(BasePImpl::GetABLdsTensorViews(p_smem).at(I0))>;
using BLdsType =
remove_cvref_t<decltype(BasePImpl::GetABLdsTensorViews(p_smem).at(I1))>;
auto&& ABLdsTensorViews = BasePImpl::GetABLdsTensorViews(p_smem);
ALdsType& a_lds_block = ABLdsTensorViews.at(I0);
BLdsType& b_lds_block = ABLdsTensorViews.at(I1);
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
using acopy_dram_type =
remove_cvref_t<decltype(BasePImpl::GetAWindows(a_dram_block_window_tmp,
a_lds_block,
a_lds_load_tile_distr)
.at(I0))>;
using bcopy_dram_type =
remove_cvref_t<decltype(BasePImpl::GetBWindows(b_dram_block_window_tmp,
b_lds_block,
b_lds_load_tile_distr)
.at(I0))>;
using a_copy_lds_window_type =
remove_cvref_t<decltype(BasePImpl::GetAWindows(a_dram_block_window_tmp,
a_lds_block,
a_lds_load_tile_distr)
.at(I1))>;
using b_copy_lds_window_type =
remove_cvref_t<decltype(BasePImpl::GetBWindows(b_dram_block_window_tmp,
b_lds_block,
b_lds_load_tile_distr)
.at(I1))>;
using a_lds_load_tile_distr_type =
remove_cvref_t<decltype(BasePImpl::GetAWindows(a_dram_block_window_tmp,
a_lds_block,
a_lds_load_tile_distr)
.at(I2))>;
using b_lds_load_tile_distr_type =
remove_cvref_t<decltype(BasePImpl::GetBWindows(b_dram_block_window_tmp,
b_lds_block,
b_lds_load_tile_distr)
.at(I2))>;
auto&& aWindows =
BasePImpl::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
auto&& bWindows =
BasePImpl::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
acopy_dram_type& a_copy_dram_window = aWindows.at(I0);
a_copy_lds_window_type& a_copy_lds_window = aWindows.at(I1);
a_lds_load_tile_distr_type& a_lds_gemm_window = aWindows.at(I2);
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
bcopy_dram_type& b_copy_dram_window = bWindows.at(I0);
b_copy_lds_window_type& b_copy_lds_window = bWindows.at(I1);
b_lds_load_tile_distr_type& b_lds_gemm_window = bWindows.at(I2);
// Block GEMM
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr =
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
using BBlockTileDistr =
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_block_tile[Base::GlobalBufferNum];
BBlockTile b_block_tile[Base::GlobalBufferNum];
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_lds_tile;
BLdsTile b_lds_tile;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// Global prefetch 1
a_block_tile[I0] = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
b_block_tile[I0] = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// Local prefill 1
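// When A is column-major (the DRAM tile is K x M) and the transpose-load path
// is unavailable, shuffle the register tile into M x K order before the LDS
// write so the data lands in the layout the block GEMM expects.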
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile[I0]);
BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
BasePImpl::LocalPrefill(a_copy_lds_window, a_block_tile[I0]);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile[I0]);
BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
BasePImpl::LocalPrefill(b_copy_lds_window, b_block_tile[I0]);
}
// Global prefetch 2
a_block_tile[I0] = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
b_block_tile[I0] = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// Global prefetch 3
a_block_tile[I1] = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
b_block_tile[I1] = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
// Local prefetch 1
BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
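// Pipeline state at loop entry: LDS holds tile 0, the two register buffers
// hold global tiles 1 and 2, and the first LDS fragments are already in
// registers. The hot loop retires two iterations per pass, ping-ponging
// between the register buffers.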
if constexpr(HasHotLoop)
{
index_t i = 0;
do
{
auto LoopFunc = [&](auto vmem_buf_idx) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
if constexpr(k0 == (KRepeat - 1))
{
block_sync_lds();
// Local prefill 2
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<
Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile[vmem_buf_idx]);
BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
BasePImpl::LocalPrefill(a_copy_lds_window,
a_block_tile[vmem_buf_idx]);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<
Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile[vmem_buf_idx]);
BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
BasePImpl::LocalPrefill(b_copy_lds_window,
b_block_tile[vmem_buf_idx]);
}
// Global prefetch 4
a_block_tile[vmem_buf_idx] =
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
b_block_tile[vmem_buf_idx] =
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
}
block_gemm(c_block_tile, a_lds_tile, b_lds_tile);
// Local prefetch 2
BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
});
HotLoopScheduler();
};
LoopFunc(I0);
LoopFunc(I1);
i += Base::HotloopUnroll;
} while(i < (num_loop - Base::PrefetchStages));
}
auto ReadWriteCompFunc = [&](auto vmem_buf_idx) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
if constexpr(k0 == (KRepeat - 1))
{
block_sync_lds();
// Local prefill 3
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile[vmem_buf_idx]);
BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
BasePImpl::LocalPrefill(a_copy_lds_window, a_block_tile[vmem_buf_idx]);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile[vmem_buf_idx]);
BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
BasePImpl::LocalPrefill(b_copy_lds_window, b_block_tile[vmem_buf_idx]);
}
block_sync_lds();
}
block_gemm(c_block_tile, a_lds_tile, b_lds_tile);
BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
});
HotLoopScheduler();
};
auto ReadCompFunc = [&]() {
static_for<0, KRepeat - 1, 1>{}([&](auto) {
__syncthreads();
block_gemm(c_block_tile, a_lds_tile, b_lds_tile);
// Local prefetch 4
BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
__syncthreads();
});
block_gemm(c_block_tile, a_lds_tile, b_lds_tile);
HotLoopScheduler();
};
if constexpr(TailNum == TailNumber::Odd)
{
ReadWriteCompFunc(I0);
ReadWriteCompFunc(I1);
ReadCompFunc();
}
else if constexpr(TailNum == TailNumber::Even)
{
ReadWriteCompFunc(I0);
ReadCompFunc();
}
return c_block_tile;
}
};
public:
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
ck_tile::make_tuple(a_dram_block_window_tmp),
[](auto& e, const ADataType& a) { e = a; },
ck_tile::make_tuple(b_dram_block_window_tmp),
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
a_element_func,
ck_tile::make_tuple(b_dram_block_window_tmp),
b_element_func,
num_loop,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,56 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
namespace ck_tile {
// Default policy for GemmPipelineAGmemBGmemCregComputeV6. Except for the block GEMM
// method, it shares the same vector-size implementation, SmemSize, and global memory
// tile distribution as the UniversalGemm pipeline policy.
// The default policy class should not be templated; put templates on the
// member functions instead.
struct GemmPipelineAgBgCrCompV6DefaultPolicy
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV6DefaultPolicy>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t vector_size =
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
constexpr auto wg_attr_num_access =
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
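// The warp GEMM access count is chosen so that one, two, or four transpose
// loads of DS_READ_TR_SIZE() bytes cover all per-thread elements of the N x K
// warp-tile face; any other ratio is rejected as Invalid.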
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile

File diff suppressed because it is too large

View File

@@ -0,0 +1,82 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <ostream>
#include <sstream>
#include "ck_tile/core.hpp"
namespace ck_tile {
enum struct CastPolicy
{
BeforeLDSWrite,
AfterLDSRead,
};
enum struct GemmPipelineScheduler
{
Default,
Intrawave,
Interwave,
};
enum struct TailNumber
{
// Single / Double buffer pipeline
Odd,
Even,
// Long prefetch pipeline, up to 8
One,
Two,
Three,
Four,
Five,
Six,
Seven,
// Unroll stages > prefetch stages; the loop count is a multiple of the unroll stages
Empty,
// Unroll stages <= prefetch stages; the loop count is a multiple of the unroll stages
// plus the prefetch stages
Full,
};
} // namespace ck_tile
inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os,
const ck_tile::GemmPipelineScheduler& s)
{
switch(s)
{
case ck_tile::GemmPipelineScheduler::Default: os << "Default"; break;
case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break;
case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break;
default: os << "";
}
return os;
}
inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os,
const ck_tile::TailNumber& s)
{
switch(s)
{
case ck_tile::TailNumber::Odd: os << "Odd"; break;
case ck_tile::TailNumber::Even: os << "Even"; break;
case ck_tile::TailNumber::One: os << "One"; break;
case ck_tile::TailNumber::Two: os << "Two"; break;
case ck_tile::TailNumber::Three: os << "Three"; break;
case ck_tile::TailNumber::Four: os << "Four"; break;
case ck_tile::TailNumber::Five: os << "Five"; break;
case ck_tile::TailNumber::Six: os << "Six"; break;
case ck_tile::TailNumber::Seven: os << "Seven"; break;
case ck_tile::TailNumber::Empty: os << "Empty"; break;
case ck_tile::TailNumber::Full: os << "Full"; break;
default: os << "";
}
return os;
}
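// Minimal usage sketch (illustrative, not part of the original header):
//   std::ostringstream oss;
//   oss << ck_tile::GemmPipelineScheduler::Intrawave << '_'
//       << ck_tile::TailNumber::Odd;   // yields "Intrawave_Odd"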

View File

@@ -0,0 +1,377 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
#include "gemm_pipeline_agmem_bgmem_creg_v1.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
{
using Base = BaseGemmPipelineAGmemBGmemCRegV1<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr bool Async = true;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool Preshuffle = Problem::Preshuffle;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t kLdsAlignmentInBytes = 16;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "BASIC_ASYNC_V1";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegAsyncV1",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
// For the basic GEMM pipeline, DoubleSmemBuffer is naturally false.
static constexpr bool DoubleSmemBuffer = false;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
template <bool HasHotLoop,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
// TODO support multi-ABD
static_assert(1 == std::tuple_size_v<AsDramBlockWindowTmp>);
static_assert(1 == std::tuple_size_v<BsDramBlockWindowTmp>);
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
// TODO currently fused elementwise are not supported
ignore = a_element_func;
ignore = b_element_func;
static_assert(std::is_same_v<AElementFunction, element_wise::PassThrough>);
static_assert(std::is_same_v<BElementFunction, element_wise::PassThrough>);
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"Data Type conflict on A and B matrix input data type.");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
////////////// global window & register /////////////////
// A DRAM tile window(s) for load
auto a_tile_windows =
make_tile_window(a_dram_block_window_tmp[I0{}].get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp[I0{}].get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// B DRAM window(s) for load
auto b_tile_windows =
make_tile_window(b_dram_block_window_tmp[I0{}].get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp[I0{}].get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// This pipeline uses a single pair of LDS buffers: one for the A tile and one
// for the B tile.
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
// set up LDS tile shapes
constexpr auto a_lds_shape = []() {
if constexpr(is_a_load_tr_v)
return make_tuple(number<kKPerBlock>{}, number<kMPerBlock>{});
else
return make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{});
}();
constexpr auto b_lds_shape = []() {
if constexpr(is_b_load_tr_v)
return make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{});
else
return make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{});
}();
// LDS tile windows for storing, one per LDS buffer
auto a_copy_lds_window = make_tile_window(a_lds_block, a_lds_shape, {0, 0});
auto b_copy_lds_window = make_tile_window(b_lds_block, b_lds_shape, {0, 0});
// Block GEMM
auto block_gemm = BlockGemm();
// Acc register tile
auto c_block_tile = block_gemm.MakeCBlockTile();
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
// tile distribution for the register tiles
constexpr auto ALdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto BLdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
// Register tiles; each register tile is paired with an LDS tile window, and
// together they form the two buffering levels of the pipeline.
ALdsTile a_block_tile;
BLdsTile b_block_tile;
constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
if constexpr(is_a_load_tr_v)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename decltype(ALdsTileDistr)::DstrEncode,
typename Problem::ADataType>::TransposedDstrEncode{});
else
return ALdsTileDistr;
}();
constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() {
if constexpr(is_b_load_tr_v)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename decltype(BLdsTileDistr)::DstrEncode,
typename Problem::BDataType>::TransposedDstrEncode{});
else
return BLdsTileDistr;
}();
// LDS tile windows for reading;
// they share the data pointer with the LDS windows for storing, but also
// carry a distribution that yields a register tile when read
auto a_lds_ld_window =
make_tile_window(a_lds_block, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
auto b_lds_ld_window =
make_tile_window(b_lds_block, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
static_assert((!(is_tile_window_linear_v<decltype(a_lds_ld_window)>) &&
!(is_tile_window_linear_v<decltype(b_lds_ld_window)>)),
"LDS windows must not be linear");
// Global Prefetch
Base::GlobalPrefetchAsync(a_copy_lds_window, a_tile_windows, a_dram_tile_window_step);
Base::GlobalPrefetchAsync(b_copy_lds_window, b_tile_windows, b_dram_tile_window_step);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
block_sync_lds_direct_load();
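// block_sync_lds_direct_load() synchronizes the asynchronous direct-to-LDS
// copies issued above; a dedicated barrier flavor is needed here because
// these loads bypass registers.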
if constexpr(HasHotLoop)
{
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
Base::LocalPrefetch(a_block_tile, a_lds_ld_window, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile, b_lds_ld_window, is_b_load_tr_v);
block_sync_lds();
Base::GlobalPrefetchAsync(
a_copy_lds_window, a_tile_windows, a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window, b_tile_windows, b_dram_tile_window_step);
// GEMM i
block_gemm(c_block_tile, a_block_tile, b_block_tile);
block_sync_lds_direct_load();
iCounter--;
}
}
// tail
{
Base::LocalPrefetch(a_block_tile, a_lds_ld_window, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile, b_lds_ld_window, is_b_load_tr_v);
// GEMM num_loop - 1
block_gemm(c_block_tile, a_block_tile, b_block_tile);
}
return c_block_tile;
}
};
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto RunPipeline = [&](auto hot_loop_) {
constexpr bool hot_loop = hot_loop_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(
a_dram_block_window_tmp,
element_wise::PassThrough{},
b_dram_block_window_tmp,
element_wise::PassThrough{},
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem);
}
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto RunPipeline = [&](auto hot_loop_) {
constexpr bool hot_loop = hot_loop_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,641 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
template <typename Problem>
struct BaseGemmPipelineAGmemBGmemCRegV1
{
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = false;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
{
return TailNumber::Empty;
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop)
{
// Use amd_wave_read_first_lane to avoid higher register usage: it forces these
// values into SGPRs, since the compiler cannot deduce on its own that all
// threads take the same path.
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
if(has_hot_loop_first_lane)
{
return run_func(ck_tile::bool_constant<true>{});
}
else
{
return run_func(ck_tile::bool_constant<false>{});
}
}
};
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
{
using Base = BaseGemmPipelineAGmemBGmemCRegV1<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool Preshuffle = Problem::Preshuffle;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t kLdsAlignmentInBytes = 16;
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "BASIC_V1";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegV1",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
// For the basic GEMM pipeline, DoubleSmemBuffer is naturally false.
static constexpr bool DoubleSmemBuffer = false;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
template <bool HasHotLoop,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"A/B DRAM block window data types must match ADataType/BDataType!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
kLdsAlignmentInBytes) *
kLdsAlignmentInBytes;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
// Block GEMM
auto block_gemm = BlockGemm();
// Acc register tile
auto c_block_tile = block_gemm.MakeCBlockTile();
// prefetch
// global read 0
// Load the A tiles; the elementwise function is applied to A0, A1, ..., AN
// during the load. Corresponding values of A0, A1, ..., AN are read by the
// same thread, so they can be combined in registers.
auto elementwise_As_res =
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
// Load the B tiles; the elementwise function is applied to B0, B1, ..., BN
// during the load. Corresponding values of B0, B1, ..., BN are read by the
// same thread, so they can be combined in registers.
auto elementwise_Bs_res =
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
{
// move to 1
// Advance each A window; this move_tile_window overload takes the whole
// tuple of windows and steps every one of them.
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
// Advance each B window likewise.
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
}
if constexpr(HasHotLoop)
{
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
// global read i + 1
elementwise_As_res =
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
elementwise_Bs_res =
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
block_sync_lds();
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// move to i + 2
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
// LDS write i + 1
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
iCounter--;
}
}
// tail
{
block_sync_lds();
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
// GEMM num_loop - 1
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
return c_block_tile;
}
};
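// Per main-loop iteration i, the Intrawave schedule above issues:
//   global read(i + 1) -> lds sync -> local prefetch -> gemm(i)
//   -> lds sync -> window move(i + 2) -> lds write(i + 1)
// so the global load of the next K-tile overlaps the MFMA work on the
// current one while a single LDS buffer is reused.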
template <>
struct PipelineImpl<GemmPipelineScheduler::Interwave> : public PipelineImplBase
{
using Base = PipelineImplBase;
template <bool HasHotLoop,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"A/B DRAM block window data types must match ADataType/BDataType!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
kLdsAlignmentInBytes) *
kLdsAlignmentInBytes;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
// Block GEMM
auto block_gemm = BlockGemm();
// Acc register tile
auto c_block_tile = block_gemm.MakeCBlockTile();
// prefetch
// global read 0
// Load the A tiles; the elementwise function is applied to A0, A1, ..., AN
// during the load. Corresponding values of A0, A1, ..., AN are read by the
// same thread, so they can be combined in registers.
auto elementwise_As_res =
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
// Load the B tiles; the elementwise function is applied to B0, B1, ..., BN
// during the load. Corresponding values of B0, B1, ..., BN are read by the
// same thread, so they can be combined in registers.
auto elementwise_Bs_res =
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
{
// move to 1
// Advance each A window; this move_tile_window overload takes the whole
// tuple of windows and steps every one of them.
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
// Advance each B window likewise.
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
}
if constexpr(HasHotLoop)
{
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
// global read i + 1
elementwise_As_res =
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
block_sync_lds();
elementwise_Bs_res =
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
// GEMM i
block_gemm(c_block_tile,
a_lds_gemm_window,
b_lds_gemm_window,
is_a_load_tr_v,
is_b_load_tr_v);
// move to i + 2
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
// LDS write i + 1
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
iCounter--;
}
}
// tail
{
block_sync_lds();
// GEMM num_loop - 1
block_gemm(c_block_tile,
a_lds_gemm_window,
b_lds_gemm_window,
is_a_load_tr_v,
is_b_load_tr_v);
}
return c_block_tile;
}
};
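// The Interwave variant above differs from Intrawave only in scheduling:
// the LDS barrier is issued between the A and B global reads, there is no
// separate LocalPrefetch call in the hot loop, and block_gemm consumes the
// LDS windows directly (with the transpose-load flags). The intent is to
// let memory and math phases of different waves overlap.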
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto RunPipeline = [&](auto hot_loop_) {
constexpr bool hot_loop = hot_loop_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(
a_dram_block_window_tmp,
element_wise::PassThrough{},
b_dram_block_window_tmp,
element_wise::PassThrough{},
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem);
}
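// Minimal usage sketch (names and shapes are illustrative, not part of this
// header):
//
//   using Pipeline = GemmPipelineAGmemBGmemCRegV1<Problem>;
//   __shared__ char smem[Pipeline::GetSmemSize()];
//   const index_t num_loop = integer_divide_ceil(K, Pipeline::kKPerBlock);
//   auto c_tile = Pipeline{}(a_dram_window, b_dram_window, num_loop, smem);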
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto RunPipeline = [&](auto hot_loop_) {
constexpr bool hot_loop = hot_loop_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,426 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
// Default policy for GemmPipelineAGmemBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
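// The leading stride of (kMPerBlock + 1) * 8, rather than kMPerBlock * 8,
// inserts one 8-element pad per K/8 slab; this presumably staggers rows
// across LDS banks so the block GEMM's reads avoid bank conflicts. The B
// descriptor below applies the same trick with kNPerBlock.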
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}),
make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<typename Problem::ADataType>>::PackedSize;
constexpr index_t smem_size_a =
sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<typename Problem::BDataType>>::PackedSize;
constexpr index_t smem_size_b =
sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
constexpr index_t smem_size = smem_size_a + smem_size_b;
return smem_size;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA()
{
using A = remove_cvref_t<typename Problem::ADataType>;
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
return (KPack < VecElems) ? KPack : VecElems;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackB()
{
using B = remove_cvref_t<typename Problem::BDataType>;
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B));
return (KPack < VecElems) ? KPack : VecElems;
}
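// Example (assumed fp16 inputs with VectorLoadSize == 16): VecElems is
// 16 / sizeof(fp16) == 8, so if the block GEMM reported KPack == 16 the
// smem pack would clamp to 8, matching the global-load vector width.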
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t M1 = Problem::VectorSizeA;
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t KPack = GetSmemPackA<Problem>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() >= (K2 * M0))
{
constexpr index_t K1 = get_warp_size() / (K2 * M0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
// coalesced reads within each block
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
}
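// Worked example for the RowMajor branch (assumed: fp16 A, BlockSize = 256,
// MPerBlock = 128, KPerBlock = 32, wave64):
//   K1 = 16 / 2 = 8, K0 = 32 / 8 = 4, M2 = 64 / 4 = 16,
//   M1 = 256 / 64 = 4, M0 = 128 / (16 * 4) = 2,
// i.e. each wave covers 16 rows x 32 K-elements per access, issuing
// 8-element vector loads along K.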
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = Problem::VectorSizeB;
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t KPack = GetSmemPackB<Problem>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() >= (K2 * N0))
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
// coalesced reads within each block
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// coalesced reads within each warp
else
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDistribution()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemPackB<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could span beyond a single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * N0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * N0);
constexpr index_t K0 = kBlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = kMPerBlock / M1;
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t kKPack = GetSmemPackA<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could span beyond a single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * M0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = kBlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
AccDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
}
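// GetBlockGemm wires a warp-level GEMM (selected by WarpGemmDispatcher from
// the compute type, the M/N/K warp-tile extents and TransposeC) into a
// block-scope GEMM through an A-smem/B-smem/C-reg custom policy. For
// instance (hypothetical), an fp16 problem with a 32x32x8 WarpTile would
// dispatch to the corresponding 32x32x8 MFMA warp GEMM.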
};
} // namespace ck_tile

View File

@@ -0,0 +1,342 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
template <typename Problem>
struct BaseGemmPipelineAGmemBGmemCRegV2
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
{
return TailNumber::Empty;
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
{
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
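// With PrefetchStages == 2 the V2 pipeline always keeps two global reads in
// flight: BlockHasHotloop() unconditionally returns true and the tail number
// is always TailNumber::Empty, so TailHandler dispatches a single variant.
// The implementation below consequently expects at least two K iterations
// (it prefetches stages 0 and 1 before entering the main loop).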
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2<Problem>
{
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Problem::VectorSizeA;
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Problem::VectorSizeB;
}
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool Preshuffle = Problem::Preshuffle;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// For the basic GEMM pipeline, DoubleSmemBuffer is naturally false.
static constexpr bool DoubleSmemBuffer = false;
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "BASIC_V2";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegV2",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize));
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
return integer_divide_ceil(sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>()
.get_element_space_size() /
APackedSize,
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size() /
BPackedSize;
}
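// Layout note: the A tile's byte footprint is rounded up to a multiple of
// 16 before the B tile is placed behind it, so the B base pointer keeps
// 16-byte alignment regardless of A's element count.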
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
struct PipelineImpl : public PipelineImplBase
{
using Base = PipelineImplBase;
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"A/B DRAM block window data types must match ADataType/BDataType!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock ==
BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"A/B block windows have incorrect lengths for the block GEMM shape!");
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size() /
APackedSize,
16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
// Block GEMM
auto block_gemm = BlockGemm();
// Acc register tile
auto c_block_tile = block_gemm.MakeCBlockTile();
// prefetch
// global read 0
// Load the A tiles; the elementwise function is applied to A0, A1, ..., AN
// during the load. Corresponding values of A0, A1, ..., AN are read by the
// same thread, so they can be combined in registers.
auto elementwise_As_res =
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
// Load the B tiles; the elementwise function is applied to B0, B1, ..., BN
// during the load. Corresponding values of B0, B1, ..., BN are read by the
// same thread, so they can be combined in registers.
auto elementwise_Bs_res =
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
{
// move to 1
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
store_tile(a_copy_lds_window, elementwise_As_res);
// global read 1
elementwise_As_res =
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
// LDS write 0
store_tile(b_copy_lds_window, elementwise_Bs_res);
// global read 1
elementwise_Bs_res =
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
}
index_t iCounter = num_loop - 2;
do
{
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// move to i + 2
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
store_tile(a_copy_lds_window, elementwise_As_res);
// global read i + 2
elementwise_As_res =
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
// LDS write i + 1
store_tile(b_copy_lds_window, elementwise_Bs_res);
// global read i + 2
elementwise_Bs_res =
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
iCounter--;
} while(iCounter > 0);
// tail
{
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
// GEMM num_loop - 2
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// LDS write num_loop - 1
store_tile(a_copy_lds_window, elementwise_As_res);
store_tile(b_copy_lds_window, elementwise_Bs_res);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
// GEMM num_loop - 1
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
return c_block_tile;
}
};
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return PipelineImpl{}.operator()(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Default policy for GemmPipelineAGmemBGmemCRegV2
// Default policy class should not be templated, put template on member functions instead
// NOTE: a policy should be bound to its corresponding operation. It is only a coincidence that
// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using GemmPipelineAGmemBGmemCRegV2DefaultPolicy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
} // namespace ck_tile

View File

@@ -0,0 +1,462 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
template <typename AsDataType_,
typename BsDataType_,
typename EDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = AsDataType_,
typename AElementWise_ = ck_tile::element_wise::PassThrough,
typename BElementWise_ = ck_tile::element_wise::PassThrough,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
struct GemmPipelineProblemBase
{
using Traits = remove_cvref_t<Traits_>;
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
static constexpr bool FixedVectorSize = FixedVectorSize_;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using AElementWise = remove_cvref_t<AElementWise_>;
using BElementWise = remove_cvref_t<BElementWise_>;
using AsLayout = remove_cvref_t<typename Traits::AsLayout>;
using BsLayout = remove_cvref_t<typename Traits::BsLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr bool ComputeDataTypeIsTuple = is_detected<is_tuple, ComputeDataType_>::value;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
static constexpr bool ALayoutIsTuple = is_detected<is_tuple, AsLayout>::value;
static constexpr bool BLayoutIsTuple = is_detected<is_tuple, BsLayout>::value;
using ComputeDataTypeTuple = std::conditional_t<ComputeDataTypeIsTuple,
remove_cvref_t<ComputeDataType_>,
remove_cvref_t<tuple<ComputeDataType_>>>;
using AsLayoutTuple = std::
conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
using BsLayoutTuple = std::
conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<AsDataType>,
remove_cvref_t<tuple<AsDataType>>>;
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
using ComputeDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, ComputeDataTypeTuple>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayoutTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayoutTuple>>;
static constexpr bool TransposeC = Traits::TransposeC;
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = Traits::kPadM;
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
// In the base problem, Preshuffle defaults to false.
static constexpr bool Preshuffle = false;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm_problem",
concat('x', VectorLoadSize, kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler);
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
? pixels_per_thread
: PackedSize * VectorLoadSize / sizeof(ADataType);
}
else
{
return PackedSize * VectorLoadSize / sizeof(ADataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
? pixels_per_thread
: PackedSize * VectorLoadSize / sizeof(BDataType);
}
else
{
return PackedSize * VectorLoadSize / sizeof(BDataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
{
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
constexpr index_t M0 = get_warp_size() / N2;
constexpr index_t M1 = BlockGemmShape::kM / M0;
return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
else
{
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = BlockGemmShape::kN / N0;
return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
}
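// Example for the RowMajor-C branch (assumed: fp16 C, wave64,
// kBlockSize = 256, kM = 128, kN = 128, VectorLoadSize = 16):
//   M1 = 4, M2 = min(128 / 4, 64) = 32, N0 = 64 / 32 = 2, N1 = 64,
// so the alignment is min(64, 16 / 2) = 8 elements per store.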
static constexpr index_t VectorSizeA = []() {
if constexpr(FixedVectorSize)
{
return VectorSizeA_;
}
else if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return kPadK ? 1 : GetAlignmentA();
}
else
{
return kPadM ? 1 : GetAlignmentA();
}
}();
static constexpr index_t VectorSizeB = []() {
if constexpr(FixedVectorSize)
{
return VectorSizeB_;
}
else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return kPadN ? 1 : GetAlignmentB();
}
else
{
return kPadK ? 1 : GetAlignmentB();
}
}();
static constexpr index_t VectorSizeC = []() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return kPadN ? 1 : GetAlignmentC();
}
else
{
return kPadM ? 1 : GetAlignmentC();
}
}();
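// Whenever the contiguous dimension is padded (kPadK for row-major A,
// kPadM for column-major A, and analogously for B and C), the run of valid
// elements at the tile edge need not be a multiple of the vector width, so
// the vector size conservatively falls back to 1.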
};
template <typename AsDataType_,
typename BsDataType_,
typename EDataType_,
typename BlockGemmShape_,
typename Traits_,
typename AElementWise_ = ck_tile::element_wise::PassThrough,
typename BElementWise_ = ck_tile::element_wise::PassThrough,
typename ComputeDataType_ = AsDataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
using GemmPipelineProblem = GemmPipelineProblemBase<AsDataType_,
BsDataType_,
EDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_,
AElementWise_,
BElementWise_,
FixedVectorSize_,
VectorSizeA_,
VectorSizeB_>;
template <typename AsDataType_,
typename BsDataType_,
typename EDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
typename AElementWise_ = ck_tile::element_wise::PassThrough,
typename BElementWise_ = ck_tile::element_wise::PassThrough,
typename ComputeDataType_ = AsDataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
struct UniversalGemmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
using AElementWise = remove_cvref_t<AElementWise_>;
using BElementWise = remove_cvref_t<BElementWise_>;
static constexpr bool FixedVectorSize = FixedVectorSize_;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using AsLayout = remove_cvref_t<typename Traits::AsLayout>;
using BsLayout = remove_cvref_t<typename Traits::BsLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr bool ComputeDataTypeIsTuple = is_detected<is_tuple, ComputeDataType_>::value;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
static constexpr bool ALayoutIsTuple = is_detected<is_tuple, AsLayout>::value;
static constexpr bool BLayoutIsTuple = is_detected<is_tuple, BsLayout>::value;
using ComputeDataTypeTuple = std::conditional_t<ComputeDataTypeIsTuple,
remove_cvref_t<ComputeDataType_>,
remove_cvref_t<tuple<ComputeDataType_>>>;
using AsLayoutTuple = std::
conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
using BsLayoutTuple = std::
conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<AsDataType>,
remove_cvref_t<tuple<AsDataType>>>;
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
using ComputeDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, ComputeDataTypeTuple>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayoutTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayoutTuple>>;
static constexpr bool TransposeC = Traits::TransposeC;
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = Traits::kPadM;
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = Scheduler_;
static constexpr bool Preshuffle = Traits::Preshuffle;
static constexpr index_t VectorSizeA = VectorSizeA_;
static constexpr index_t VectorSizeB = VectorSizeB_;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm_problem",
concat('x', kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler,
"NumWaveGroups",
NumWaveGroups,
"DoubleSmemBuffer",
DoubleSmemBuffer
);
// clang-format on
}
};
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default,
bool BPreShufflePermute_ = false,
typename ComputeDataType_ = ADataType_>
struct FlatmmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename Traits::AsLayout>;
using BLayout = remove_cvref_t<typename Traits::BsLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr bool TransposeC = Traits::TransposeC;
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = Traits::kPadM;
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static constexpr auto BMemNTType = BMemNTType_;
static constexpr bool BPreShufflePermute = BPreShufflePermute_;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm_problem",
concat('x', VectorLoadSize, kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler);
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
? pixels_per_thread
: PackedSize * VectorLoadSize / sizeof(ADataType);
}
else
{
return VectorLoadSize / sizeof(ADataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
? pixels_per_thread
: PackedSize * VectorLoadSize / sizeof(BDataType);
}
else
{
return PackedSize * VectorLoadSize / sizeof(BDataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
{
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
constexpr index_t M0 = get_warp_size() / N2;
constexpr index_t M1 = BlockGemmShape::kM / M0;
return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
else
{
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = BlockGemmShape::kN / N0;
return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
}
static constexpr index_t VectorSizeA = []() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return kPadK ? 1 : GetAlignmentA();
}
else
{
return kPadM ? 1 : GetAlignmentA();
}
}();
static constexpr index_t VectorSizeB = []() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return kPadN ? 1 : GetAlignmentB();
}
else
{
return kPadK ? 1 : GetAlignmentB();
}
}();
static constexpr index_t VectorSizeC = []() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return kPadN ? 1 : GetAlignmentC();
}
else
{
return kPadM ? 1 : GetAlignmentC();
}
}();
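// Note: whenever the contiguous dimension of a tensor is padded, a row may start
// at an unaligned address, so the vector size conservatively falls back to 1.
// E.g. a RowMajor A tensor with kPadK == true gives VectorSizeA == 1 even if
// GetAlignmentA() would otherwise permit 8-element loads.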
};
} // namespace ck_tile

View File

@@ -0,0 +1,22 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile {
enum struct GemmPipeline
{
COMPUTE_ASYNC,
COMPUTE_V3,
COMPUTE_V4,
COMPUTE_V5,
COMPUTE_V6,
MEMORY,
BASIC_V1,
BASIC_V2,
PRESHUFFLE_V2,
BASIC_ASYNC_V1
};
} // namespace ck_tile

View File

@@ -0,0 +1,942 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
template <typename T, typename = void>
struct has_a_tile_access_pattern : std::false_type
{
};
template <typename T>
struct has_a_tile_access_pattern<T, std::void_t<decltype(T::ATileAccessPattern)>> : std::true_type
{
};
template <typename T, typename = void>
struct has_b_tile_access_pattern : std::false_type
{
};
template <typename T>
struct has_b_tile_access_pattern<T, std::void_t<decltype(T::BTileAccessPattern)>> : std::true_type
{
};
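// Usage sketch (hypothetical derived policy, shown for illustration only):
//
// struct MyGemmPolicy : UniversalGemmBasePolicy<MyGemmPolicy>
// {
//     // Opting in: getATileAccessPattern() below picks this member up via the
//     // has_a_tile_access_pattern detector above.
//     static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
// };
//
// A policy that omits the member silently falls back to
// DefaultATileAccessPattern (thread_raked).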
template <typename Derived>
struct UniversalGemmBasePolicy
{
#if defined(__gfx950__)
// The combination of pk_int4_t and transposed loading causes numerical errors.
// Therefore do not use transposed loading in this case.
// Also, transpose load (ds_read_tr) requires specific tile distribution patterns
// that only work for certain K warp tile sizes based on data type size:
// - For 1-byte types (fp8/bf8): K warp tile <= 64
// - For 2-byte types (fp16/bf16): K warp tile <= 32
template <typename Problem>
static constexpr bool is_a_load_tr = []() {
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
// Max K warp tile for transpose load based on data type size
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
if constexpr(std::is_same_v<ADataType, float>)
return false;
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
return std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
tensor_layout::gemm::ColumnMajor>;
}();
template <typename Problem>
static constexpr bool is_b_load_tr = []() {
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
// Max K warp tile for transpose load based on data type size
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
if constexpr(std::is_same_v<BDataType, float>)
return false;
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
return std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
tensor_layout::gemm::RowMajor>;
}();
#else
template <typename Problem>
static constexpr bool is_a_load_tr = false;
template <typename Problem>
static constexpr bool is_b_load_tr = false;
#endif
template <typename T>
using has_bcastpolicy_type = decltype(T::BCastPolicy);
template <typename Problem>
static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] {
if constexpr(is_detected<has_bcastpolicy_type, Problem>{})
{
return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
}
else
{
return false;
}
}();
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
// Default tile access patterns
static constexpr auto DefaultATileAccessPattern = tile_distribution_pattern::thread_raked;
static constexpr auto DefaultBTileAccessPattern = tile_distribution_pattern::thread_raked;
static constexpr auto getATileAccessPattern()
{
if constexpr(has_a_tile_access_pattern<Derived>::value)
return Derived::ATileAccessPattern;
else
return DefaultATileAccessPattern;
}
static constexpr auto getBTileAccessPattern()
{
if constexpr(has_b_tile_access_pattern<Derived>::value)
return Derived::BTileAccessPattern;
else
return DefaultBTileAccessPattern;
}
template <typename Problem,
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = OverrideADataType;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(is_a_load_tr<Problem>)
{
// TODO: better lds descriptor for performance
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<KPerBlock>{}, number<MPerBlock>{}),
make_tuple(number<MPerBlock>{}, number<1>{}),
number<MPerBlock>{},
number<1>{});
return a_lds_block_desc_0;
}
else
{
// Only use this ColumnMajor layout for Wave64 mode (gfx9)
constexpr auto Wave64 = get_warp_size() == 64;
if constexpr(Wave64 &&
std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
// The kfold and mpair dimensions are not always required.
// More dimensions in the merge_transform make it harder for the compiler to
// generate immediate-argument (immarg) offsets.
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
MPerBlock,
VecLoadSize,
getATileAccessPattern()>;
// AK1: the shuffled tile dstr has shape <X1, Y2>, use Y2 as AK1
constexpr auto AK1 = number<TileEncodingPattern::Y2>{};
constexpr auto AK0 = number<KPerBlock / AK1>{};
// How the M dimension is split across threads
constexpr auto M0 = TileEncodingPattern::X0; // # of threads in M dim
constexpr auto M1 = number<MPerBlock / M0>{};
// Get the warp tile size
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr auto MPerXdl = number<WarpTile::at(I0)>{};
// Number of threads covering K dimension
constexpr auto KThreadWrite = TileEncodingPattern::Y0 * TileEncodingPattern::Y1;
constexpr auto K0PerThreadWrite = AK0 / KThreadWrite;
constexpr auto KThreadRead = get_warp_size() / MPerXdl;
constexpr auto K0PerThreadRead = AK0 / KThreadRead;
// check if we exceed all LDS banks
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_dwords_per_128b();
constexpr auto kfold = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth)
? 1
: LdsBanksWidth / (AK1 * M0 * sizeof(ADataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1 <= mpair <= M0
constexpr auto mpair =
(AK1 * MPerXdl * sizeof(ADataType) > LdsBanksWidth)
? 1
: ((LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType))) > M0
? M0
: LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * M1>{},
number<kfold * M0 / mpair>{},
number<mpair>{},
AK1),
AK1);
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(make_pass_through_transform(
number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(make_tuple(number<KThreadReadPerm * M1>{},
number<kfold * M0 / mpair>{})),
make_pass_through_transform(number<mpair>{}),
make_pass_through_transform(AK1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2, 3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2, 3>{},
sequence<4>{},
sequence<5>{}));
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(
number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<M1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<M0 / mpair>{})),
make_pass_through_transform(number<mpair>{}),
make_pass_through_transform(AK1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
AK1)),
make_merge_transform_v3_division_mod(make_tuple(
number<M0 / mpair>{}, number<mpair>{}, number<M1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return a_lds_block_desc_ak0_m_ak1;
}
else // A is in RowMajor
{
constexpr index_t KPack = Derived::template GetSmemPackA<Problem>();
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr uint64_t MinLdsLayer = 1ULL;
constexpr auto MLdsLayer =
max(MinLdsLayer,
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
constexpr index_t NBanks = get_n_lds_banks();
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
number<MPerBlock / MLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer * RowMul>{},
number<KPerBlock / KPack * MLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(
number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
}
}
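// The xor_transform above is the usual LDS bank-conflict-avoidance swizzle: it
// permutes row indices across banks so the strided MFMA read pattern does not
// repeatedly hit the same bank, while the kfold/mpair folding keeps a full
// vector access within the bank width computed above.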
/**
* @brief Create LDS block descriptor for B tensor.
*
* @tparam Problem Gemm pipeline problem.
* @return B tensor LDS block descriptor.
*/
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(is_b_load_tr<Problem>)
{
// TODO: better lds descriptor for performance
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<KPerBlock>{}, number<NPerBlock>{}),
make_tuple(number<NPerBlock>{}, number<1>{}),
number<NPerBlock>{},
number<1>{});
return b_lds_block_desc_0;
}
else
{
// Only use this RowMajor layout for Wave64 mode (gfx9)
constexpr auto Wave64 = get_warp_size() == 64;
if constexpr(Wave64 && std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
getBTileAccessPattern()>;
// BK1: the shuffled tile dstr has shape <X1, Y2>, use Y2 as BK1
constexpr auto BK1 = number<TileEncodingPattern::Y2>{};
constexpr auto BK0 = number<KPerBlock / BK1>{};
// How threads access data on N dim
constexpr auto N0 = TileEncodingPattern::X0; // # of threads in N dim
constexpr auto N1 = number<NPerBlock / N0>{};
// Get NPerXdl, the warp tile size
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr auto NPerXdl = number<WarpTile::at(I1)>{};
// Number of threads covering K dimension
constexpr auto KThreadWrite = TileEncodingPattern::Y0 * TileEncodingPattern::Y1;
constexpr auto K0PerThreadWrite = BK0 / KThreadWrite;
constexpr auto KThreadRead = get_warp_size() / NPerXdl;
constexpr auto K0PerThreadRead = BK0 / KThreadRead;
// check if we exceed all LDS banks
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_dwords_per_128b();
constexpr auto kfold = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth)
? 1
: LdsBanksWidth / (BK1 * N0 * sizeof(BDataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1 <= npair <= N0
constexpr auto npair =
(BK1 * NPerXdl * sizeof(BDataType) > LdsBanksWidth)
? 1
: ((LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType))) > N0
? N0
: LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * N1>{},
number<kfold * N0 / npair>{},
number<npair>{},
BK1),
BK1);
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(make_pass_through_transform(
number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(make_tuple(number<KThreadReadPerm * N1>{},
number<kfold * N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(BK1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2, 3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2, 3>{},
sequence<4>{},
sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(
number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(BK1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(
sequence<1>{}, // 1: KThreadWrite / kfold / KThreadReadPerm
sequence<2>{}, // 2: K0PerThreadWrite
sequence<0, 3>{}, // 0: KThreadReadPerm, 3: N1
sequence<4, 5>{}, // 4: kfold, 5: N0 / npair
sequence<6>{}, // 6: npair
sequence<7>{})); // 7: BK1
constexpr auto b_lds_block_desc_nk = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
BK1)),
make_merge_transform_v3_division_mod(make_tuple(
number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return b_lds_block_desc_nk;
}
else // B is Column Major
{
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto BK0 = number<KPerBlock / KPack>{};
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr uint64_t MinLdsLayer = 1ULL;
constexpr auto NLdsLayer =
max(MinLdsLayer,
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
constexpr index_t NBanks = get_n_lds_banks();
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(BK0 * number<NLdsLayer>{},
number<NPerBlock / NLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(
make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer * RowMul>{},
BK0 * number<NLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(number<NLdsLayer>{}, BK0)),
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_bk0_nldslayer_n_bk1,
make_tuple(
make_merge_transform_v3_division_mod(
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform_v3_division_mod(make_tuple(BK0, number<KPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
}
}
/**
* @brief Get the maximum global memory vector load size.
*
* @tparam Problem The UniversalGemmPipelineProblem object.
* @tparam DataType The tensor data type we're considering.
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
* @tparam XPerTile The contiguous Tile dimension size.
* @tparam IsWave32Host Whether the host-side query should halve the block size for wave32.
* @return Maximum DRAM vector load size.
*/
template <typename Problem,
typename DataType,
index_t MNPerBlock,
index_t XPerTile,
bool IsWave32Host>
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
{
constexpr index_t BlockSize = IsWave32Host ? Problem::kBlockSize / 2 : Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
// Assumes sizeof(DataType) evenly divides the candidate vector widths below.
if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
{
return (PackedSize * 16 / sizeof(DataType));
}
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
{
return (PackedSize * 8 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
{
return (PackedSize * 4 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
{
return (PackedSize * 2 / sizeof(DataType));
}
else
{
return PackedSize;
}
}
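// Worked example (illustrative): fp16 (sizeof == 2, PackedSize == 1),
// BlockSize == 256, MNPerBlock == 128, KPerBlock == 64, XPerTile == 64:
//   elements_per_thread = 128 * 64 / 256 = 32;
//   16-byte branch: 16 / 2 = 8 elements, 64 % 8 == 0 and 32 % 8 == 0 -> return 8.
// With XPerTile == 4 the cascade instead falls through to the 8-byte branch
// (4 elements): 4 % 4 == 0 and 32 % 4 == 0 -> return 4.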
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA()
{
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
if constexpr(Problem::FixedVectorSize)
{
return Problem::VectorSizeA;
}
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem,
ADataType,
MPerBlock,
KPerBlock,
IsWave32Host>();
}
else
{
return GetGlobalVectorLoadSize<Problem,
ADataType,
MPerBlock,
MPerBlock,
IsWave32Host>();
}
}
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
{
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
typename Problem::ADataType,
typename Problem::BDataType>;
if constexpr(Problem::FixedVectorSize)
{
return Problem::VectorSizeB;
}
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem,
BDataType,
NPerBlock,
NPerBlock,
IsWave32Host>();
}
else
{
return GetGlobalVectorLoadSize<Problem,
BDataType,
NPerBlock,
KPerBlock,
IsWave32Host>();
}
}
/**
* @brief Get the vector store size for C tensor.
*
* @tparam Problem - Gemm pipeline problem class.
*
* @note The vector store size for the output C tensor depends on multiple
* factors like its data layout and warp gemm C transposition. In general it
* is the number of consecutive elements in the contiguous C dimension held
* by a single thread.
*
* @return The vector store size for C tensor.
*/
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
using WG = typename BlockGemm::WarpGemm;
constexpr bool TransposeC = Problem::TransposeC;
using CLayout = typename Problem::CLayout;
using CWarpDstr = typename WG::CWarpDstr;
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread holds multiple consecutive elements in the
// N dimension, while consecutive threads' elements are strided.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread holds just a single item in the N dim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread holds just a single item in the M dim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
else
{
// In this case each thread holds multiple consecutive elements in the
// M dimension, while consecutive threads' elements are strided.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Problem::TransposeC;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using ALayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
// Tile: MPerBlock X KPerBlock
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
MPerBlock,
KPerBlock,
VecLoadSize,
getATileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
// Tile: KPerBlock X MPerBlock
else
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
MPerBlock,
VecLoadSize,
getATileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// If we cast before writing to LDS, the vector size is defined by the A data
// type, since the B tile stored in LDS then uses the A data type.
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
constexpr index_t VecLoadSize =
IsBCastPolicyBeforeLDSWrite
? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>())
: (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>());
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using BLayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
// Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
// Tile: NPerBlock X KPerBlock
else
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
NPerBlock,
KPerBlock,
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
{
using ALayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
MPerBlock,
VecLoadSize,
getATileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
{
using BLayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA()
{
using A = remove_cvref_t<typename Problem::ADataType>;
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
return (KPack < VecElems) ? KPack : VecElems;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackB()
{
using B = remove_cvref_t<typename Problem::BDataType>;
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B));
return (KPack < VecElems) ? KPack : VecElems;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
constexpr index_t smem_size_a = integer_least_multiple(
a_lds_block_desc.get_element_space_size() * sizeof(ADataType) / APackedSize, 16);
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t smem_size_b = integer_least_multiple(
b_lds_block_desc.get_element_space_size() * sizeof(BDataType) / BPackedSize, 16);
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
return smem_size_a + smem_size_b;
}
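// E.g. (ignoring any swizzle padding in the LDS descriptors): 128x64 fp16 A and
// B tiles each occupy at least 128 * 64 * 2 = 16384 bytes, so GetSmemSize()
// reports at least 32768 bytes of shared memory.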
};
// UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy
: public UniversalGemmBasePolicy<UniversalGemmPipelineAgBgCrPolicy>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t vector_size =
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
constexpr auto wg_attr_num_access =
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_t> ||
sizeof(BDataType) < sizeof(ADataType),
ADataType,
BDataType>;
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
Problem::UseStructuredSparsity,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,69 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
template <typename BlockTile_,
typename BlockWarps_,
typename WarpTile_,
bool PermuteA_ = false,
bool PermuteB_ = false>
struct TileGemmShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
using WarpTile = remove_cvref_t<WarpTile_>;
static constexpr index_t NumWarps =
reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
static constexpr index_t kM = BlockTile::at(number<0>{});
static constexpr index_t kN = BlockTile::at(number<1>{});
static constexpr index_t kK = BlockTile::at(number<2>{});
static constexpr bool PermuteA = PermuteA_;
static constexpr bool PermuteB = PermuteB_;
static constexpr index_t flatNPerWarp = BlockWarps::at(number<1>{});
static constexpr index_t flatKPerWarp = WarpTile::at(number<2>{}) * WarpTile::at(number<1>{});
static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(number<2>{});
CK_TILE_HOST static std::string GetName()
{
// clang-format off
return concat('_', "tile_gemm_shape",
concat('x', kM, kN, kK, NumWarps),
concat('x', BlockWarps::at(number<0>{}), BlockWarps::at(number<1>{}), BlockWarps::at(number<2>{})),
concat('x', (WarpTile::at(number<0>{})), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})));
// clang-format on
}
};
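// Usage sketch (illustrative tile sizes, not a tuned configuration):
//
// using Shape = TileGemmShape<sequence<128, 128, 32>,  // block tile  M, N, K
//                             sequence<2, 2, 1>,       // warps per block
//                             sequence<32, 32, 16>>;   // warp tile   M, N, K
//
// Shape::kM == 128, Shape::kN == 128, Shape::kK == 32, Shape::NumWarps == 4.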
template <typename PrecType, index_t M_Warp_Tile, bool IsFlatMM = false>
constexpr index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
return 16;
#else
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32;
else
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64;
#endif
#endif
}
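// Usage sketch: the result depends on the architecture macros above. For
// example, get_k_warp_tile<fp16_t, 32>() evaluates to 16 on every path here:
// the WMMA branch returns 16, gfx950 returns 16 for non-8-bit floats with
// M_Warp_Tile == 32, and the MFMA fallback returns 16 for 2-byte types.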
} // namespace ck_tile

View File

@@ -0,0 +1,87 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
typename AsLayout_,
typename BsLayout_,
typename CLayout_,
index_t NumWaveGroups_ = 1>
struct TileGemmTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
// TODO this can't be hardcoded here! Should be in policy!
static constexpr int _VectorSize = 16;
using AsLayout = AsLayout_;
using BsLayout = BsLayout_;
using CLayout = CLayout_;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = NumWaveGroups_;
};
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool DoubleSmemBuffer_,
typename AsLayout_,
typename BsLayout_,
typename CLayout_,
bool TransposeC_ = false,
bool UseStructuredSparsity_ = false,
bool UsePersistentKernel_ = false,
index_t NumWaveGroups_ = 1,
bool Preshuffle_ = false,
int VectorSize_ = 16>
struct TileGemmUniversalTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
static constexpr int _VectorSize = VectorSize_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
using AsLayout = AsLayout_;
using BsLayout = BsLayout_;
using CLayout = CLayout_;
static constexpr bool TransposeC = TransposeC_;
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
static constexpr index_t NumWaveGroups = NumWaveGroups_;
static constexpr bool Preshuffle = Preshuffle_;
};
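// Usage sketch (hypothetical instantiation; the layout tuples follow the
// AsLayout/BsLayout convention used by the pipelines in this directory):
//
// using Traits = TileGemmUniversalTraits<
//     /*kPadM=*/false, /*kPadN=*/false, /*kPadK=*/false,
//     /*DoubleSmemBuffer=*/false,
//     ck_tile::tuple<tensor_layout::gemm::RowMajor>,     // AsLayout
//     ck_tile::tuple<tensor_layout::gemm::ColumnMajor>,  // BsLayout
//     tensor_layout::gemm::RowMajor>;                    // CLayout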
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool DoubleSmemBuffer_,
typename AsLayout_,
typename BsLayout_,
typename CLayout_,
bool TransposeC_ = false,
bool UseStructuredSparsity_ = false>
using PersistentTileGemmUniversalTraits = TileGemmUniversalTraits<kPadM_,
kPadN_,
kPadK_,
DoubleSmemBuffer_,
AsLayout_,
BsLayout_,
CLayout_,
TransposeC_,
UseStructuredSparsity_,
true>;
} // namespace ck_tile

View File

@@ -0,0 +1,365 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
struct UniversalWeightPreshufflePipelineAgBgCrPolicy
: public UniversalGemmBasePolicy<UniversalWeightPreshufflePipelineAgBgCrPolicy>
{
using BasePolicy = UniversalGemmBasePolicy<UniversalWeightPreshufflePipelineAgBgCrPolicy>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
#if defined(__gfx11__)
constexpr index_t scale = 4;
#else
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
#endif
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) * scale / 2;
}
else
{
static_assert(TileShape::WarpTile::at(I1) == 16);
return TileShape::WarpTile::at(I2) * scale / 4;
}
}
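// Worked example (illustrative): WarpTile = <M, 32, 64> on a wave64 target
// (scale == 1) gives KBPerLoad = 64 * 1 / 2 = 32; a 16-wide N warp tile with
// the same K gives 64 * 1 / 4 = 16.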
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t KPack = GetSmemPackA<Problem>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() >= (K2 * M0))
{
constexpr index_t K1 = get_warp_size() / (K2 * M0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
// coalesced reads for each block
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
using BDataType = typename Problem::BDataType;
constexpr index_t kNPerBlock = TileShape::kN;
constexpr index_t kKPerBlock = TileShape::kK;
constexpr index_t NIterPerWarp =
kNPerBlock / TileShape::BlockWarps::at(I1) / TileShape::WarpTile::at(I1);
constexpr index_t KIterPerWarp = kKPerBlock / TileShape::WarpTile::at(I2);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
#if defined(__gfx11__)
constexpr index_t KRepeatInWave = 2;
#else
constexpr index_t KRepeatInWave = 1;
#endif
constexpr index_t KBPerLoad = min(
GetKBPerLoad<Problem>(), KRepeatInWave * 16 / static_cast<index_t>(sizeof(BDataType)));
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = KIterPerWarp;
constexpr index_t KAccess = GetKBPerLoad<Problem>() / KBPerLoad;
static_assert(TileShape::flatKPerWarp == KAccess * KThdPerWave * KBPerLoad,
"flatKPerWarp must equal KAccess * KThdPerWave * KBPerLoad");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t NRepeat = NIterPerWarp;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat, KRepeatInWave>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KAccess, KWavePerBlk, KThdPerWave, KBPerLoad>>,
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
tuple<sequence<0, 1, 2>, sequence<1, 2, 3>>, // which index
// <repeat, vec_load>
sequence<1, 2, 1, 2, 2>,
sequence<0, 0, 3, 1, 4>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = kMPerBlock / M1;
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t kKPack = GetSmemPackA<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could live outside a single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size >= (K2 * M0))
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = kBlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
return GetBlockWeightPreshuffle<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffle()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
// Determine the compute types to use.
// This logic defaults to the A/B DataType, but if one of them is a packed type
// it falls back to the other. If both are packed, it falls back to the
// ComputeDataType defined explicitly in the problem. It might be a good idea
// to use ComputeDataType unconditionally, but that would change the
// established behaviour.
using ATypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::ComputeDataType>;
using BTypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::BDataType,
typename Problem::ADataType,
typename Problem::ComputeDataType>;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
// When BDataType is pk_int4_t, it is internally converted to fp8 for computation.
constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
NumAccess>;
using BlockWeightPreshufflePolicy =
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockWeightPreshuffleASmemBRegCReg<Problem, BlockWeightPreshufflePolicy>{};
}
/**
* @brief Get the vector store size for C tensor.
*
* @tparam Problem - Gemm pipeline problem class.
*
* @note The vector store size for the output C tensor depends on multiple
* factors like its data layout and warp gemm C transposition. In general it
* is the number of consecutive elements in the contiguous C dimension held
* by a single thread.
*
* @return The vector store size for C tensor.
*/
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
using BlockGemm = remove_cvref_t<decltype(GetBlockWeightPreshuffle<Problem>())>;
using WG_ = typename BlockGemm::WarpGemm;
constexpr bool TransposeC = Problem::TransposeC;
using CLayout = typename Problem::CLayout;
using CWarpDstr = typename WG_::CWarpDstr;
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread holds multiple consecutive elements in the
// N dimension, while consecutive threads' elements are strided.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread holds just a single item in the N dim
return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread holds just a single item in the M dim
return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
}
else
{
// In this case each thread holds multiple consecutive elements in the
// M dimension, while consecutive threads' elements are strided.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,796 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
namespace ck_tile {
template <typename Problem>
struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
if(has_hot_loop)
{
if(tail_number == TailNumber::Odd)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
else // Even tail number
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Even>{});
}
}
else
{
if(tail_number == TailNumber::Odd)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
else // Even tail number
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Even>{});
}
}
}
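// Dispatch sketch (hypothetical caller; pipeline_op stands in for the real
// pipeline body). TailHandler lifts the runtime flags into compile-time
// template arguments, so each (hot loop, tail) combination is instantiated
// separately:
//
// return TailHandler(
//     [&](auto has_hot_loop_, auto tail_number_) {
//         return pipeline_op(has_hot_loop_, tail_number_);
//     },
//     BlockHasHotloop(num_loop),
//     GetBlockLoopTailNum(num_loop));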
};
template <typename Problem, typename PipelinePolicy = UniversalWeightPreshufflePipelineAgBgCrPolicy>
struct WeightPreshufflePipelineAGmemBGmemCRegV2
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>
{
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, PipelinePolicy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using BlockWeightPreshuffle =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockWeightPreshuffle<Problem>())>;
static constexpr index_t DsWritePreIssue = 3; // ds_write pre-issued at MIter - (DsWritePreIssue - 1)
static constexpr index_t DsReadPreload = 2; // number of ds_read operations to preload
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
// bogus variables to compile grouped gemm (to be removed)
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock;
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return PipelinePolicy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return PipelinePolicy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC()
{
return PipelinePolicy::template GetVectorSizeC<Problem>();
}
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr index_t kLdsAlignmentInBytes = 16;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto idxM = I0;
static constexpr auto idxN = I1;
static constexpr auto idxK = I2;
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
static constexpr index_t MWarp = BlockWarps::at(I0);
static constexpr index_t NWarp = BlockWarps::at(I1);
static constexpr index_t WarpTileM = WarpTile::at(I0);
static constexpr index_t WarpTileN = WarpTile::at(I1);
static constexpr index_t WarpTileK = WarpTile::at(I2);
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpTileM);
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpTileN);
static constexpr index_t KIterPerWarp = kKPerBlock / WarpTileK;
static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
? DsReadPreload
: MIterPerWarp * KIterPerWarp;
#ifdef __gfx942__
static constexpr index_t mfma_per_wg = 2;
#else
static constexpr index_t mfma_per_wg = 1;
#endif
static constexpr index_t dsread_per_wg = max(
index_t(WarpTileM * WarpTileK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1);
#if defined(__HIP_DEVICE_COMPILE__)
static_assert((WarpTileM * WarpTileK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
Problem::VectorLoadSize ==
0);
#endif
static constexpr index_t dsread_num_perK = WarpTileM * WarpTileK * sizeof(ADataType) *
MIterPerWarp / WaveSize / Problem::VectorLoadSize;
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
static constexpr index_t Aload_num_perK = dswrite_num_perK;
static constexpr index_t Aload_rep = dswrite_rep;
static constexpr index_t Bload_num_perK = kNPerBlock * WarpTileK / NWarp / K1 / WaveSize;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "PRESHUFFLE_V2";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegV2",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
concat('x', WarpTileM, WarpTileN, WarpTileK),
concat('x', GetVectorSizeA(), GetVectorSizeB()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool Preshuffle = Problem::Preshuffle;
using Base::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
return DoubleSmemBuffer ? 2 * smem_size : smem_size;
}
    // dsread_perM: number of LDS reads to issue in this M-iteration
    // dswrite_perM: number of LDS writes to issue in this M-iteration
    // load_perM: number of global (VMEM) loads to issue in this M-iteration
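    // E.g. (hypothetical counts, shown only to illustrate the interleaving):
    // dsread_perM = 1, dswrite_perM = 1, load_perM = 2 with mfma_perM_perK = 4
    // produce inst_order = {1, 2, 3, 2} and emit the sequence
    //   MFMA, ds_write, MFMA, ds_read, MFMA, vmem_load, MFMA, vmem_load.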
CK_TILE_HOST_DEVICE static constexpr auto
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
{
// Init inst order
        index_t max_data_inst = max(dsread_perM, max(dswrite_perM, load_perM));
index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM;
index_t round_data_inst = ck_tile::integer_divide_ceil(sum_data_inst, mfma_perM_perK);
constexpr int kOrderCap = NIterPerWarp * 10;
index_t inst_order[kOrderCap] = {};
index_t index = 0;
#pragma unroll
        // Round-robin interleave of the three instruction kinds
        // (1 = LDS write, 2 = global load, 3 = LDS read):
        // Index: 0 1 2 3 4 5 ...
        // Value: 1 2 3 1 2 3 ...
for(int j = 0; j < max_data_inst; j++)
{
if(dswrite_perM > j)
{
inst_order[index] = 1;
index++;
}
if(load_perM > j)
{
inst_order[index] = 2;
index++;
}
if(dsread_perM > j)
{
inst_order[index] = 3;
index++;
}
}
// Schedule IGLP
#pragma unroll
for(int j = 0; j < mfma_perM_perK; j++)
{
index_t inst_idx = 0;
            if(j == 1)
                inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2;
            else if(j == 2)
                inst_idx = mfma_perM_perK - 1;
            else if(j > 2)
                inst_idx = mfma_perM_perK - j;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
#pragma unroll
for(int r = 0; r < round_data_inst; r++)
{
if(r % 2 == 0)
{
if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if(inst_order[inst_idx + r * mfma_perM_perK] == 3)
{
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
}
else
{
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3)
{
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
}
}
}
}
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
        // The key to this pipeline optimization is balancing the workload over time.
        // Instruction schedule example (128x256x256 block tile, 1x4 warps, 16x16x128 warp tile):
// Iter MNK MFMA ds_read ds_write A_load b_load
// -1 M6N0: 57 - 8 - -
// -1 M6N1: 58 1 - - -
// -1 M6N2: 59 - - 7 -
// -1 M6N3: 60 2 - - -
// -1 M7N0: 61 - - - -
// -1 M7N1: 62 3 - - -
// -1 M7N2: 63 - - 8 -
// -1 M7N3: 64 4 - - -
// 0 M0N0K0: 1 - - - 1
// 0 M0N1: 2 5 - - -
// 0 M0N2: 3 - - - 2
// 0 M0N3: 4 6 - - -
// 0 M1N0: 5 - - - 3
// 0 M1N1: 6 7 - - -
// 0 M1N2: 7 - - - 4
// 0 M1N3: 8 8 - - -
// 0 M2N0: 9 - - - 5
// 0 M2N1: 10 9 - - -
// 0 M2N2: 11 - - - 6
// 0 M2N3: 12 10 - - -
// 0 M3N0: 13 - 1 - 7
// 0 M3N1: 14 11 - - -
// 0 M3N2: 15 - - - 8
// 0 M3N3: 16 12 - - -
// 0 M4N0: 17 - 2 - -
// 0 M4N1: 18 13 - - -
// 0 M4N2: 19 - - 1 -
// 0 M4N3: 20 14 - - -
// 0 M5N0: 21 - 3 - -
// 0 M5N1: 22 15 - - -
// 0 M5N2: 23 - - 2 -
// 0 M5N3: 24 16 - - -
// 0 M6N0: 25 - 4 - -
// 0 M6N1: 26 17 - - -
// 0 M6N2: 27 - - 3 -
// 0 M6N3: 28 18 - - -
// 0 M7N0: 29 - - - -
// 0 M7N1: 30 19 - - -
// 0 M7N2: 31 - - 4 -
// 0 M7N3: 32 20 - - -
// 0 M0N0K1: 33 - - - 9
// 0 M0N1: 34 21 - - -
// 0 M0N2: 35 - - - 10
// 0 M0N3: 36 22 - - -
// 0 M1N0: 37 - - - 11
// 0 M1N1: 38 23 - - -
// 0 M1N2: 39 - - - 12
// 0 M1N3: 40 24 - - -
// 0 M2N0: 41 - - - 13
// 0 M2N1: 42 25 - - -
// 0 M2N2: 43 - - - 14
// 0 M2N3: 44 26 - - -
// 0 M3N0: 45 - 5 - 15
// 0 M3N1: 46 27 - - -
// 0 M3N2: 47 - - - 16
// 0 M3N3: 48 28 - - -
// 0 M4N0: 49 - 6 - -
// 0 M4N1: 50 29 - - -
// 0 M4N2: 51 - - 5 -
// 0 M4N3: 52 30 - - -
// 0 M5N0: 53 - 7 - -
// 0 M5N1: 54 31 - - -
// 0 M5N2: 55 - - 6 -
// 0 M5N3: 56 32 - - -
// 0 M6N0: 57 - 8 - -
// 0 M6N1: 58 1 - - -
// 0 M6N2: 59 - - 7 -
// 0 M6N3: 60 2 - - -
// 0 M7N0: 61 - - - -
// 0 M7N1: 62 3 - - -
// 0 M7N2: 63 - - 8 -
// 0 M7N3: 64 4 - - -
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
index_t load_perM = 0;
// Calculate ds_read number per M
dsread_perM = dsread_per_wg;
// Calculate ds_write number per M
if(mIter == 0)
{
dswrite_perM =
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0
? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep
: 0;
}
else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
{
dswrite_perM = 0;
}
else
{
dswrite_perM = (dswrite_num_perK -
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
? dswrite_rep
: 0;
}
                // When dswrite_num_perK rounds down to zero, pre-issue a single
                // LDS write at the designated (kIter, mIter) slot
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
{
if(mIter == MIterPerWarp - 1 - dswrite_mIter)
dswrite_perM = 1;
}
// Calculate buffer_load number per M
if(mIter < HalfMIter)
{
load_perM =
((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep
: 0) +
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
: 0);
}
else
{
load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0
? Aload_rep
: 0;
}
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
}
        // If no A global loads were scheduled above (Aload_num_perK == 0), issue one VMEM read
if(Aload_num_perK == 0)
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_barrier(0);
}
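    // Scheduler for the second-to-last loop iteration: A global loads are no
    // longer needed, so only LDS reads, LDS writes, and the remaining
    // first-half B loads are interleaved with the MFMAs.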
CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler()
{
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
index_t load_perM = 0;
// Calculate ds_read number per M
dsread_perM = dsread_per_wg;
// Calculate ds_write number per M
if(mIter == 0)
{
dswrite_perM =
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0
? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep
: 0;
}
else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
{
dswrite_perM = 0;
}
else
{
dswrite_perM = (dswrite_num_perK -
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
? dswrite_rep
: 0;
}
                // When dswrite_num_perK rounds down to zero, pre-issue a single
                // LDS write at the designated (kIter, mIter) slot
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
{
if(mIter == MIterPerWarp - 1 - dswrite_mIter)
dswrite_perM = 1;
}
// Calculate buffer_load number per M
if(mIter < HalfMIter)
{
load_perM =
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
: 0);
}
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
}
__builtin_amdgcn_sched_barrier(0);
}
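    // Scheduler for the final loop iteration: no global loads or LDS writes
    // remain; only the LDS reads not already covered by the m_preload
    // prefetches are interleaved with the MFMAs.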
CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler()
{
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
index_t load_perM = 0;
// Calculate ds_read number per M
if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
dsread_perM = dsread_per_wg;
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
}
// __builtin_amdgcn_sched_barrier(0);
}
struct PipelineImpl : public PipelineImplBase
{
using Base = PipelineImplBase;
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr,
index_t UnaryOpSize_ = 8>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
[[maybe_unused]] const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
            static_assert(
                std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
                "ADramBlockWindowTmp data type must match ADataType!");
            static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
                          "A block window M length must equal kMPerBlock!");
            static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                          "A block window K length must equal kKPerBlock!");
// A tile in LDS
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_blocks = generate_tuple(
[&](auto i) {
ADataType* p_a_lds = static_cast<ADataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + smem_size * i.value));
return make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
},
number<2>{});
constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(
BlockWeightPreshuffle::MakeABlockDistributionEncode());
auto&& windows_result =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_blocks, a_lds_load_tile_distr);
auto&& a_copy_dram_window = windows_result.template get<0>();
auto&& a_lds_windows = windows_result.template get<1>();
auto a_copy_lds_windows = generate_tuple(
[&](auto i) -> decltype(auto) { return a_lds_windows[i].template at<0>(); },
number<2>{});
// Block GEMM
auto block_weight_preshuffle = BlockWeightPreshuffle();
// Acc register tile
auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
auto a_load_windows = generate_tuple(
[&](auto i) -> decltype(auto) {
return block_weight_preshuffle.MakeALoadWindows(a_copy_lds_windows[i]);
},
number<2>{});
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(b_flat_dram_block_window_tmp
.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp * NIterPerWarp>{},
number<flatKPerWarp * KIterPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BFlatBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kflatKPerBlock);
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using BBlockTile =
decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
ABlockTile a_global_tile;
BBlockTile b_global_tile[2];
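            // A is staged through double-buffered LDS, while B stays in
            // registers and is double-buffered so the next flat B tile can be
            // fetched while the current one feeds the MFMAs.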
            // Prefetch A0 and B0
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
// Prefill A0
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
// Prefetch A1
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
block_sync_lds();
            // preload A00, A10 from LDS
block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]);
__builtin_amdgcn_sched_barrier(0);
// MAIN LOOP
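            // Each do-while iteration consumes two K-blocks, ping-ponging
            // between LDS buffer 0/1 for A and register buffer 0/1 for B:
            // while buffer i feeds the MFMAs, buffer i ^ 1 is being refilled.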
if constexpr(HasHotLoop)
{
index_t i_global_read = amd_wave_read_first_lane(2);
do
{
{
Base::GlobalPrefetch(
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
Base::GlobalPrefetch(
a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
block_weight_preshuffle(c_block_tile,
a_load_windows[I0],
b_global_tile[0],
b_flat_distribution);
block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]);
HotLoopScheduler();
}
{
Base::GlobalPrefetch(
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
Base::GlobalPrefetch(
a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
block_weight_preshuffle(c_block_tile,
a_load_windows[I1],
b_global_tile[1],
b_flat_distribution);
block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]);
HotLoopScheduler();
}
i_global_read += 2;
} while(i_global_read < num_loop);
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
{
Base::GlobalPrefetch(
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
block_weight_preshuffle(
c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution);
block_sync_lds();
block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]);
Last2ndHotLoopScheduler();
}
{
block_weight_preshuffle(
c_block_tile, a_load_windows[I1], b_global_tile[1], b_flat_distribution);
LastHotLoopScheduler();
}
}
else if constexpr(TailNum == TailNumber::Odd)
{
block_weight_preshuffle(
c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution);
LastHotLoopScheduler();
}
return c_block_tile;
}
};
// called from universal gemm kernel
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, ADramBlockWindowTmp>::value &&
is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
[[maybe_unused]] const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
[[maybe_unused]] const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp[number<0>{}],
a_element_func,
b_flat_dram_block_window_tmp[number<0>{}],
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
// called from general gemm kernel
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr auto PassThrough = [](const ADataType& a) { return a; };
return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
PassThrough,
b_flat_dram_block_window_tmp,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
// called from grouped gemm kernel
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
TailNumber tail_number,
void* __restrict__ p_smem) const
{
const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr auto PassThrough = [](const auto& x) { return x; };
return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
PassThrough,
b_flat_dram_block_window_tmp,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
};
} // namespace ck_tile